diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 100df35ea..b4a84f16d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [ '3.8', '3.9', '3.10', '3.11' ] + python-version: [ '3.9', '3.10', '3.11' ] max-parallel: 4 steps: @@ -42,6 +42,8 @@ jobs: timeout 25m python -m unittest ${second_half_nl[@]} - name: Regression run: source tests/resources/regression/run_regression.sh + - name: Model backwards compatibility + run: source tests/resources/model_compatibility/check_backwards_compatibility.sh - name: Get the latest release version id: get_latest_release uses: actions/github-script@v6 diff --git a/docs/main.md b/docs/main.md index 99817cee4..f80b12b1a 100644 --- a/docs/main.md +++ b/docs/main.md @@ -122,12 +122,12 @@ If you have access to UMLS or SNOMED-CT, you can download the pre-built CDB and A basic trained model is made public. It contains ~ 35K concepts available in `MedMentions`. This was compiled from MedMentions and does not have any data from [NLM](https://www.nlm.nih.gov/research/umls/) as that data is not publicaly available. Model packs: -- MedMentions with Status (Is Concept Affirmed or Negated/Hypothetical) [Download](https://medcat.rosalind.kcl.ac.uk/media/medmen_wstatus_2021_oct.zip) +- MedMentions with Status (Is Concept Affirmed or Negated/Hypothetical) [Download](https://cogstack-medcat-example-models.s3.eu-west-2.amazonaws.com/medcat-example-models/medmen_wstatus_2021_oct.zip) Separate models: -- Vocabulary [Download](https://medcat.rosalind.kcl.ac.uk/media/vocab.dat) - Built from MedMentions -- CDB [Download](https://medcat.rosalind.kcl.ac.uk/media/cdb-medmen-v1_2.dat) - Built from MedMentions -- MetaCAT Status [Download](https://medcat.rosalind.kcl.ac.uk/media/mc_status.zip) - Built from a sample from MIMIC-III, detects is an annotation Affirmed (Positve) or Other (Negated or Hypothetical) +- Vocabulary [Download](https://cogstack-medcat-example-models.s3.eu-west-2.amazonaws.com/medcat-example-models/vocab.dat) - Built from MedMentions +- CDB [Download](https://cogstack-medcat-example-models.s3.eu-west-2.amazonaws.com/medcat-example-models/cdb-medmen-v1.dat) - Built from MedMentions +- MetaCAT Status [Download](https://cogstack-medcat-example-models.s3.eu-west-2.amazonaws.com/medcat-example-models/mc_status.zip) - Built from a sample from MIMIC-III, detects is an annotation Affirmed (Positve) or Other (Negated or Hypothetical) ## Acknowledgements Entity extraction was trained on [MedMentions](https://github.com/chanzuckerberg/MedMentions) In total it has ~ 35K entites from UMLS diff --git a/docs/requirements.txt b/docs/requirements.txt index 7e7df6e01..226900abf 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -2,103 +2,105 @@ sphinx==6.2.1 sphinx-rtd-theme~=1.0 myst-parser~=0.17 sphinx-autoapi~=3.0.0 -MarkupSafe==2.1.3 -accelerate==0.23.0 -aiofiles==23.2.1 -aiohttp==3.8.5 +MarkupSafe==2.1.5 +accelerate==0.34.2 +aiofiles==24.1.0 +aiohttp==3.10.5 aiosignal==1.3.1 -asttokens==2.4.0 +asttokens==2.4.1 async-timeout==4.0.3 -attrs==23.1.0 +attrs==24.2.0 backcall==0.2.0 blis==0.7.11 catalogue==2.0.10 -certifi==2023.7.22 -charset-normalizer==3.3.0 +certifi==2024.8.30 +charset-normalizer==3.3.2 click==8.1.7 -comm==0.1.4 -confection==0.1.3 +comm==0.2.2 +confection==0.1.5 cymem==2.0.8 -datasets==2.14.5 +darglint==1.8.1 +datasets==2.21.0 decorator==5.1.1 -dill==0.3.7 -exceptiongroup==1.1.3 -executing==2.0.0 -filelock==3.12.4 
-flake8==4.0.1 -frozenlist==1.4.0 -fsspec==2023.6.0 -gensim==4.3.2 -huggingface-hub==0.17.3 -idna==3.4 -ipython==8.16.1 -ipywidgets==8.1.1 +dill==0.3.8 +exceptiongroup==1.2.2 +executing==2.1.0 +filelock==3.16.0 +flake8==7.0.0 +frozenlist==1.4.1 +fsspec==2024.6.1 +gensim==4.3.3 +huggingface-hub==0.24.7 +idna==3.10 +ipython==8.27.0 +ipywidgets==8.1.5 jedi==0.19.1 -jinja2==3.1.2 -joblib==1.3.2 -jsonpickle==3.0.2 -jupyterlab-widgets==3.0.9 -langcodes==3.3.0 -matplotlib-inline==0.1.6 -mccabe==0.6.1 +jinja2==3.1.4 +joblib==1.4.2 +jsonpickle==3.3.0 +jupyterlab-widgets==3.0.13 +langcodes==3.4.0 +matplotlib-inline==0.1.7 +mccabe==0.7.0 mpmath==1.3.0 -multidict==6.0.4 -multiprocess==0.70.15 +multidict==6.1.0 +multiprocess==0.70.16 murmurhash==1.0.10 -mypy==1.0.0 -mypy-extensions==0.4.3 -networkx==3.1 +mypy==1.11.2 +mypy-extensions==1.0.0 +networkx==3.3 numpy==1.25.2 -packaging==23.2 -pandas==2.1.1 -parso==0.8.3 -pathy==0.10.2 -pexpect==4.8.0 +packaging==24.1 +pandas==2.2.2 +parso==0.8.4 +pathy==0.11.0 +peft==0.12.0 +pexpect==4.9.0 pickleshare==0.7.5 preshed==3.0.9 -prompt-toolkit==3.0.39 -psutil==5.9.5 +prompt-toolkit==3.0.47 +psutil==6.0.0 ptyprocess==0.7.0 -pure-eval==0.2.2 -pyarrow==13.0.0 -pycodestyle==2.8.0 -pydantic==1.10.13 -pyflakes==2.4.0 -pygments==2.16.1 -python-dateutil==2.8.2 -pytz==2023.3.post1 -pyyaml==6.0.1 -regex==2023.10.3 -requests==2.31.0 -safetensors==0.4.0 -scikit-learn==1.3.1 +pure-eval==0.2.3 +pyarrow==17.0.0 +pycodestyle==2.11.1 +pydantic==1.10.18 +pyflakes==3.2.0 +pygments==2.18.0 +python-dateutil==2.9.0 +pytz==2024.2 +pyyaml==6.0.2 +regex==2024.9.11 +requests==2.32.3 +safetensors==0.4.5 +scikit-learn==1.5.2 scipy==1.9.3 six==1.16.0 smart-open==6.4.0 -spacy==3.4.4 +spacy==3.6.1 spacy-legacy==3.0.12 spacy-loggers==1.0.5 srsly==2.4.8 stack-data==0.6.3 -sympy==1.12 +sympy==1.13.2 thinc==8.1.12 -threadpoolctl==3.2.0 -tokenizers==0.14.1 +threadpoolctl==3.5.0 +tokenizers==0.19.1 tomli==2.0.1 -torch==2.1.0 -tqdm==4.66.1 -traitlets==5.11.2 -transformers==4.34.0 -triton==2.1.0 -typer==0.7.0 +torch==2.4.1 +tqdm==4.66.5 +traitlets==5.14.3 +transformers==4.44.2 +triton==3.0.0 +typer==0.9.4 types-PyYAML==6.0.3 types-aiofiles==0.8.3 types-setuptools==57.4.10 -typing-extensions==4.8.0 -tzdata==2023.3 -urllib3==2.0.6 -wasabi==0.10.1 -wcwidth==0.2.8 -widgetsnbextension==4.0.9 -xxhash==3.4.1 -yarl==1.9.2 \ No newline at end of file +typing-extensions==4.12.2 +tzdata==2024.1 +urllib3==2.2.3 +wasabi==1.1.3 +wcwidth==0.2.13 +widgetsnbextension==4.0.13 +xxhash==3.5.0 +yarl==1.11.1 \ No newline at end of file diff --git a/examples/cdb_new.dat b/examples/cdb_new.dat deleted file mode 100644 index 27957d62b..000000000 Binary files a/examples/cdb_new.dat and /dev/null differ diff --git a/install_requires.txt b/install_requires.txt index ebd380c65..77b610825 100644 --- a/install_requires.txt +++ b/install_requires.txt @@ -1,7 +1,7 @@ 'numpy>=1.22.0,<1.26.0' # 1.22.0 is first to support python 3.11; post 1.26.0 there's issues with scipy 'pandas>=1.4.2' # first to support 3.11 'gensim>=4.3.0,<5.0.0' # 5.3.0 is first to support 3.11; avoid major version bump -'spacy>=3.6.0,<4.0.0' # Some later model packs (e.g HPO) are made with 3.6.0 spacy model; avoid major version bump +'spacy>=3.6.0,<3.8.0' # 3.8 only supports numpy2 which we can't use due to other dependencies 'scipy~=1.9.2' # 1.9.2 is first to support 3.11 'transformers>=4.34.0,<5.0.0' # avoid major version bump 'accelerate>=0.23.0' # required by Trainer class in de-id @@ -21,4 +21,4 @@ 'click>=8.0.4' # allow later versions, tested with 8.1.3 
'pydantic>=1.10.0,<2.0' # for spacy compatibility; avoid 2.0 due to breaking changes "humanfriendly~=10.0" # for human readable file / RAM sizes -"peft>=0.8.2" \ No newline at end of file +"peft>=0.8.2" diff --git a/medcat/cat.py b/medcat/cat.py index 621a2e831..707dbd7f3 100644 --- a/medcat/cat.py +++ b/medcat/cat.py @@ -1127,11 +1127,29 @@ def get_entities_multi_texts(self, self.pipe.set_error_handler(self._pipe_error_handler) try: texts_ = self._get_trimmed_texts(texts) + if self.config.general.usage_monitor.enabled: + input_lengths: List[Tuple[int, int]] = [] + for orig_text, trimmed_text in zip(texts, texts_): + if orig_text is None or trimmed_text is None: + l1, l2 = 0, 0 + else: + l1 = len(orig_text) + l2 = len(trimmed_text) + input_lengths.append((l1, l2)) docs = self.pipe.batch_multi_process(texts_, n_process, batch_size) - for doc in tqdm(docs, total=len(texts_)): + for doc_nr, doc in tqdm(enumerate(docs), total=len(texts_)): doc = None if doc.text.strip() == '' else doc out.append(self._doc_to_out(doc, only_cui, addl_info, out_with_text=True)) + if self.config.general.usage_monitor.enabled: + l1, l2 = input_lengths[doc_nr] + if doc is None: + nents = 0 + elif self.config.general.show_nested_entities: + nents = len(doc._.ents) # type: ignore + else: + nents = len(doc.ents) # type: ignore + self.usage_monitor.log_inference(l1, l2, nents) # Currently spaCy cannot mark which pieces of texts failed within the pipe so be this workaround, # which also assumes texts are different from each others. @@ -1637,6 +1655,9 @@ def _mp_cons(self, in_q: Queue, out_list: List, min_free_memory: float, logger.warning("PID: %s failed one document in _mp_cons, running will continue normally. \n" + "Document length in chars: %s, and ID: %s", pid, len(str(text)), i_text) logger.warning(str(e)) + if self.config.general.usage_monitor.enabled: + # NOTE: This is in another process, so need to explicitly flush + self.usage_monitor._flush_logs() sleep(2) def _add_nested_ent(self, doc: Doc, _ents: List[Span], _ent: Union[Dict, Span]) -> None: diff --git a/medcat/meta_cat.py b/medcat/meta_cat.py index 8c73e6178..386bbe0cf 100644 --- a/medcat/meta_cat.py +++ b/medcat/meta_cat.py @@ -257,20 +257,19 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data category_value2id = g_config['category_value2id'] if not category_value2id: # Encode the category values - data_undersampled, full_data, category_value2id = encode_category_values(data, + full_data, data_undersampled, category_value2id = encode_category_values(data, category_undersample=self.config.model.category_undersample) g_config['category_value2id'] = category_value2id else: # We already have everything, just get the data - data_undersampled, full_data, category_value2id = encode_category_values(data, + full_data, data_undersampled, category_value2id = encode_category_values(data, existing_category_value2id=category_value2id, category_undersample=self.config.model.category_undersample) g_config['category_value2id'] = category_value2id # Make sure the config number of classes is the same as the one found in the data if len(category_value2id) != self.config.model['nclasses']: logger.warning( - "The number of classes set in the config is not the same as the one found in the data: {} vs {}".format( - self.config.model['nclasses'], len(category_value2id))) + "The number of classes set in the config is not the same as the one found in the data: %d vs %d",self.config.model['nclasses'], len(category_value2id)) logger.warning("Auto-setting 
the nclasses value in config and rebuilding the model.") self.config.model['nclasses'] = len(category_value2id) diff --git a/medcat/ner/transformers_ner.py b/medcat/ner/transformers_ner.py index 32eb23520..1de8d6d83 100644 --- a/medcat/ner/transformers_ner.py +++ b/medcat/ner/transformers_ner.py @@ -4,7 +4,7 @@ import datasets from spacy.tokens import Doc from datetime import datetime -from typing import Iterable, Iterator, Optional, Dict, List, cast, Union, Tuple, Callable +from typing import Iterable, Iterator, Optional, Dict, List, cast, Union, Tuple, Callable, Type from spacy.tokens import Span import inspect from functools import partial @@ -87,7 +87,13 @@ def create_eval_pipeline(self): # NOTE: this will fix the DeID model(s) created before medcat 1.9.3 # though this fix may very well be unstable self.ner_pipe.tokenizer._in_target_context_manager = False + if not hasattr(self.ner_pipe.tokenizer, 'split_special_tokens'): + # NOTE: this will fix the DeID model(s) created with transformers before 4.42 + # and allow them to run with later transforemrs + self.ner_pipe.tokenizer.split_special_tokens = False self.ner_pipe.device = self.model.device + self._consecutive_identical_failures = 0 + self._last_exception: Optional[Tuple[str, Type[Exception]]] = None def get_hash(self) -> str: """A partial hash trying to catch differences between models. @@ -390,34 +396,33 @@ def _process(self, #all_text_processed = self.tokenizer.encode_eval(all_text) # For now we will process the documents one by one, should be improved in the future to use batching for doc in docs: - try: - res = self.ner_pipe(doc.text, aggregation_strategy=self.config.general['ner_aggregation_strategy']) - doc.ents = [] # type: ignore - for r in res: - inds = [] - for ind, word in enumerate(doc): - end_char = word.idx + len(word.text) - if end_char <= r['end'] and end_char > r['start']: - inds.append(ind) - # To not loop through everything - if end_char > r['end']: - break - if inds: - entity = Span(doc, min(inds), max(inds) + 1, label=r['entity_group']) - entity._.cui = r['entity_group'] - entity._.context_similarity = r['score'] - entity._.detected_name = r['word'] - entity._.id = len(doc._.ents) - entity._.confidence = r['score'] - - doc._.ents.append(entity) - create_main_ann(self.cdb, doc) - if self.cdb.config.general['make_pretty_labels'] is not None: - make_pretty_labels(self.cdb, doc, LabelStyle[self.cdb.config.general['make_pretty_labels']]) - if self.cdb.config.general['map_cui_to_group'] is not None and self.cdb.addl_info.get('cui2group', {}): - map_ents_to_groups(self.cdb, doc) - except Exception as e: - logger.warning(e, exc_info=True) + res = self.ner_pipe(doc.text, aggregation_strategy=self.config.general['ner_aggregation_strategy']) + doc.ents = [] # type: ignore + for r in res: + inds = [] + for ind, word in enumerate(doc): + end_char = word.idx + len(word.text) + if end_char <= r['end'] and end_char > r['start']: + inds.append(ind) + # To not loop through everything + if end_char > r['end']: + break + if inds: + entity = Span(doc, min(inds), max(inds) + 1, label=r['entity_group']) + entity._.cui = r['entity_group'] + entity._.context_similarity = r['score'] + entity._.detected_name = r['word'] + entity._.id = len(doc._.ents) + entity._.confidence = r['score'] + + doc._.ents.append(entity) + create_main_ann(self.cdb, doc) + if self.cdb.config.general['make_pretty_labels'] is not None: + make_pretty_labels(self.cdb, doc, LabelStyle[self.cdb.config.general['make_pretty_labels']]) + if 
self.cdb.config.general['map_cui_to_group'] is not None and self.cdb.addl_info.get('cui2group', {}): + map_ents_to_groups(self.cdb, doc) + self._consecutive_identical_failures = 0 # success + self._last_exception = None yield from docs # Override diff --git a/medcat/ner/vocab_based_ner.py b/medcat/ner/vocab_based_ner.py index 97a24dca1..259699ff6 100644 --- a/medcat/ner/vocab_based_ner.py +++ b/medcat/ner/vocab_based_ner.py @@ -42,13 +42,21 @@ def __call__(self, doc: Doc) -> Doc: name_versions = [tkn._.norm, tkn.lower_] name = "" + nv_in_snames = [] + nv_in_names = [] for name_version in name_versions: + # NOTE: if the entire token is an actual concept, we want to capture that + # previous implementation could fail in those cases if name_version in self.cdb.snames: - if name: - name = name + self.config.general.separator + name_version - else: - name = name_version - break + nv_in_snames.append(name_version) + if name_version in self.cdb.name2cuis: + nv_in_names.append(name_version) + if nv_in_names: + # TODO: should we prefer 0th (i.e the normalised version) or last (the lower case version) + name = nv_in_names[0] + elif nv_in_snames: + # TODO: should we prefer 0th (i.e the normalised version) or last (the lower case version) + name = nv_in_snames[0] if name in self.cdb.name2cuis and not tkn.is_stop: maybe_annotate_name(name, tkns, doc, self.cdb, self.config) diff --git a/medcat/utils/meta_cat/data_utils.py b/medcat/utils/meta_cat/data_utils.py index 17059d7f4..3d0431308 100644 --- a/medcat/utils/meta_cat/data_utils.py +++ b/medcat/utils/meta_cat/data_utils.py @@ -166,12 +166,12 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict Name of class that should be used to undersample the data (for 2 phase learning) Returns: - dict: - New underesampled data (for 2 phase learning) with integers inplace of strings for category values dict: New data with integers inplace of strings for category values. dict: - Map rom category value to ID for all categories in the data. + New undersampled data (for 2 phase learning) with integers inplace of strings for category values + dict: + Map from category value to ID for all categories in the data. 
""" data = list(data) if existing_category_value2id is not None: @@ -180,23 +180,6 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict category_value2id = {} category_values = set([x[2] for x in data]) - # Ensuring that each label has data and checking for class imbalance - - label_data = {key: 0 for key in category_value2id} - for i in range(len(data)): - if data[i][2] in category_value2id: - label_data[data[i][2]] = label_data[data[i][2]] + 1 - - # If a label has no data, changing the mapping - if 0 in label_data.values(): - category_value2id_: Dict = {} - keys_ls = [key for key, value in category_value2id.items() if value != 0] - for k in keys_ls: - category_value2id_[k] = len(category_value2id_) - - logger.warning("Labels found with 0 data; updates made\nFinal label encoding mapping:", category_value2id_) - category_value2id = category_value2id_ - for c in category_values: if c not in category_value2id: category_value2id[c] = len(category_value2id) @@ -210,6 +193,8 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict for i in range(len(data)): if data[i][2] in category_value2id.values(): label_data_[data[i][2]] = label_data_[data[i][2]] + 1 + + logger.info("Original label_data: %s",label_data_) # Undersampling data if category_undersample is None or category_undersample == '': min_label = min(label_data_.values()) @@ -232,9 +217,9 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict for i in range(len(data_undersampled)): if data_undersampled[i][2] in category_value2id.values(): label_data[data_undersampled[i][2]] = label_data[data_undersampled[i][2]] + 1 - logger.info(f"Updated label_data: {label_data}") + logger.info("Updated label_data: %s",label_data) - return data_undersampled, data, category_value2id + return data, data_undersampled, category_value2id def json_to_fake_spacy(data: Dict, id2text: Dict) -> Iterable: diff --git a/medcat/utils/meta_cat/ml_utils.py b/medcat/utils/meta_cat/ml_utils.py index 3559ce1d8..a7acf34c9 100644 --- a/medcat/utils/meta_cat/ml_utils.py +++ b/medcat/utils/meta_cat/ml_utils.py @@ -66,7 +66,7 @@ class label of the data x = torch.tensor(x, dtype=torch.long).to(device) # cpos = torch.tensor(cpos, dtype=torch.long).to(device) - attention_masks = (x != 0).type(torch.int) + attention_masks = (x != pad_id).type(torch.int) return x, cpos, attention_masks, y @@ -201,7 +201,7 @@ def train_model(model: nn.Module, data: List, config: ConfigMetaCAT, save_dir_pa y_ = [x[2] for x in train_data] class_weights = compute_class_weight(class_weight="balanced", classes=np.unique(y_), y=y_) config.train['class_weights'] = class_weights.tolist() - logger.info(f"Class weights computed: {class_weights}") + logger.info("Class weights computed: %s",class_weights) class_weights = torch.FloatTensor(class_weights).to(device) if config.train['loss_funct'] == 'cross_entropy': @@ -259,7 +259,7 @@ def initialize_model(classifier, data_, batch_size_, lr_, epochs=4): # Total number of training steps total_steps = int((len(data_) / batch_size_) * epochs) - logger.info('Total steps for optimizer: {}'.format(total_steps)) + logger.info('Total steps for optimizer: %d',total_steps) # Set up the learning rate scheduler scheduler_ = get_linear_schedule_with_warmup(optimizer_, @@ -412,10 +412,16 @@ def eval_model(model: nn.Module, data: List, config: ConfigMetaCAT, tokenizer: T precision, recall, f1, support = precision_recall_fscore_support(y_eval, predictions, average=score_average) labels = [name 
for (name, _) in sorted(config.general['category_value2id'].items(), key=lambda x: x[1])] + labels_present_: set = set(predictions) + labels_present: List[str] = [str(x) for x in labels_present_] + + if len(labels) != len(labels_present): + logger.warning( + "The evaluation dataset does not contain all the labels, some labels are missing. Performance displayed for labels found...") confusion = pd.DataFrame( data=confusion_matrix(y_eval, predictions, ), - columns=["true " + label for label in labels], - index=["predicted " + label for label in labels], + columns=["true " + label for label in labels_present], + index=["predicted " + label for label in labels_present], ) examples: Dict = {'FP': {}, 'FN': {}, 'TP': {}} diff --git a/medcat/utils/meta_cat/models.py b/medcat/utils/meta_cat/models.py index 774cabff5..543e0ca6b 100644 --- a/medcat/utils/meta_cat/models.py +++ b/medcat/utils/meta_cat/models.py @@ -91,7 +91,7 @@ def __init__(self, config): super(BertForMetaAnnotation, self).__init__() _bertconfig = AutoConfig.from_pretrained(config.model.model_variant,num_hidden_layers=config.model['num_layers']) if config.model['input_size'] != _bertconfig.hidden_size: - logger.warning(f"\nInput size for {config.model.model_variant} model should be {_bertconfig.hidden_size}, provided input size is {config.model['input_size']} Input size changed to {_bertconfig.hidden_size}") + logger.warning("Input size for %s model should be %d, provided input size is %d. Input size changed to %d",config.model.model_variant,_bertconfig.hidden_size,config.model['input_size'],_bertconfig.hidden_size) bert = BertModel.from_pretrained(config.model.model_variant, config=_bertconfig) self.config = config diff --git a/medcat/utils/normalizers.py b/medcat/utils/normalizers.py index 532075961..4e713f3a3 100644 --- a/medcat/utils/normalizers.py +++ b/medcat/utils/normalizers.py @@ -82,17 +82,22 @@ def known(self, words: Iterable[str]) -> Set[str]: return set(w for w in words if w in self.vocab) def edits1(self, word: str) -> Set[str]: + return self.get_edits1(word, self.config.general.diacritics) + + @classmethod + def get_edits1(cls, word: str, use_diacritics: bool) -> Set[str]: """All edits that are one edit away from `word`. Args: word (str): The word. + use_diacritics (bool): Whether to use diacritics or not. Returns: Set[str]: The set of all edits """ letters = 'abcdefghijklmnopqrstuvwxyz' - if self.config.general.diacritics: + if use_diacritics: letters += 'àáâãäåæçèéêëìíîïðñòóôõöøùúûüýþÿ' splits = [(word[:i], word[i:]) for i in range(len(word) + 1)] @@ -119,6 +124,41 @@ def edits3(self, word): pass +def get_all_edits_n(word: str, use_diacritics: bool, n: int, + return_ordered: bool = False) -> Iterator[str]: + """Get all N-th order edits of a word. + + The output can be ordered. This can be useful when run-to-run + is of concern. But by default this should be avoided where possible + since it adds overhead and limits the operations permitted on the + returned value (i.e for distance 1, in unordered case you get a set). + + Args: + word (str): The original word. + use_diacritics (bool): Whether or not to use diacritics. + n (int): The number of edits to allow. + return_ordered (bool): Whether to order the output. Defaults to False. + + Raises: + ValueError: If the number of edits is smaller than 0. + + Yields: + Iterator[str]: The generator of the various edits. 
+ """ + if n < 0: + raise ValueError(f"Unknown edit count: {n}") + if n == 0: + yield word + return + edits = BasicSpellChecker.get_edits1(word, use_diacritics) + f_edits = sorted(edits) if return_ordered else edits + if n == 1: + yield from f_edits + return + for edited_word in f_edits: + yield from get_all_edits_n(edited_word, use_diacritics, n - 1, return_ordered) + + class TokenNormalizer(PipeRunner): """Will normalize all tokens in a spacy document. diff --git a/medcat/utils/preprocess_snomed.py b/medcat/utils/preprocess_snomed.py index ac78548a7..dc6f4c3a3 100644 --- a/medcat/utils/preprocess_snomed.py +++ b/medcat/utils/preprocess_snomed.py @@ -3,6 +3,9 @@ import re import hashlib import pandas as pd +from typing import Dict, List, Optional, Tuple +from dataclasses import dataclass, field +from enum import Enum, auto def parse_file(filename, first_row_header=True, columns=None): @@ -61,6 +64,174 @@ def get_direct_refset_mapping(in_dict: dict) -> dict: return ret_dict + + +_IGNORE_TAG = '##IGNORE-THIS##' + + +class RefSetFileType(Enum): + concept = auto() + description = auto() + relationship = auto() + refset = auto() + + +@dataclass +class FileFormatDescriptor: + concept: str + description: str + relationship: str + refset: str + common_prefix: str = "sct2_" # for concept, description, and relationship (but not refset) + + @classmethod + def ignore_all(cls) -> 'FileFormatDescriptor': + return cls(concept=_IGNORE_TAG, description=_IGNORE_TAG, + relationship=_IGNORE_TAG, refset=_IGNORE_TAG) + + def get_file_per_type(self, file_type: RefSetFileType) -> str: + raw = self._get_raw(file_type) + return raw if file_type == RefSetFileType.refset else self.common_prefix + raw + + def _get_raw(self, file_type: RefSetFileType) -> str: + return getattr(self, file_type.name) + + def get_concept(self) -> str: + return self.get_file_per_type(RefSetFileType.concept) + + def get_description(self) -> str: + return self.get_file_per_type(RefSetFileType.description) + + def get_relationship(self) -> str: + return self.get_file_per_type(RefSetFileType.relationship) + + def get_refset(self) -> str: + return self.get_file_per_type(RefSetFileType.refset) + + +@dataclass +class ExtensionDescription: + exp_name_in_folder: str + exp_files: FileFormatDescriptor + exp_2nd_part_in_folder: Optional[str] = None + + +# pattern has: EXTENSION PRODUCTION RELEASE +SNOMED_FOLDER_NAME_PATTERN = re.compile("^SnomedCT_([A-Za-z0-9]+)_([A-Za-z0-9]+)_(\d{8}T\d{6}Z$)") +PER_FILE_TYPE_PATHS = { + RefSetFileType.concept: os.path.join("Snapshot", "Terminology"), + RefSetFileType.description: os.path.join("Snapshot", "Terminology"), + RefSetFileType.relationship: os.path.join("Snapshot", "Terminology"), + RefSetFileType.refset: os.path.join("Snapshot", "Refset", "Map"), +} + + + +class SupportedExtension(Enum): + INTERNATIONAL = ExtensionDescription( + exp_name_in_folder="InternationalRF2", + exp_files=FileFormatDescriptor( + concept="Concept_Snapshot", + description="Description_Snapshot-en", + relationship="Relationship_Snapshot", + # NOTE: the below will be ignored for UK_CLIN bundle + refset="der2_iisssccRefset_ExtendedMapSnapshot" + ), + ) + UK_CLINICAL = ExtensionDescription( + exp_name_in_folder="UKClinicalRF2", + exp_files=FileFormatDescriptor( + concept="Concept_UKCLSnapshot", + description="Description_UKCLSnapshot-en", + relationship="Relationship_UKCLSnapshot", + refset="der2_iisssciRefset_ExtendedMapUKCLSnapshot" + ), + ) + UK_CLINICAL_REFSET = ExtensionDescription( + exp_name_in_folder="UKClinicalRefsetsRF2", + 
exp_files=FileFormatDescriptor.ignore_all() + ) + UK_EDITION = ExtensionDescription( + exp_name_in_folder="UKEditionRF2", + exp_files=FileFormatDescriptor( + concept="Concept_UKEDSnapshot", + description="Description_UKEDSnapshot-en", + relationship="Relationship_UKEDSnapshot", + refset="der2_iisssciRefset_ExtendedMapUKEDSnapshot" + ), + ) + UK_DRUG = ExtensionDescription( + exp_name_in_folder="UKDrugRF2", + exp_files=FileFormatDescriptor( + concept="Concept_UKDGSnapshot", + description="Description_UKDGSnapshot-en", + relationship="Relationship_UKDGSnapshot", + refset="der2_iisssciRefset_ExtendedMapUKDGSnapshot", + ), + ) + AU = ExtensionDescription( + exp_name_in_folder="Release", + exp_2nd_part_in_folder="AU1000036", + exp_files=FileFormatDescriptor( + concept="Concept_Snapshot", + description="Description_Snapshot-en-AU", + relationship="Relationship_Snapshot", + refset=_IGNORE_TAG, + ), + ) + + +@dataclass +class BundleDescriptor: + extensions: List[SupportedExtension] + ignores: Dict[RefSetFileType, List[SupportedExtension]] = field(default_factory=dict) + + def has_invalid(self, ext: SupportedExtension, file_types: Tuple[RefSetFileType]) -> bool: + for ft in file_types: + if ft not in self.ignores: + continue + exts2ignore = self.ignores[ft] + if ext in exts2ignore: + return True + return False + + +class SupportedBundles(Enum): + UK_CLIN = BundleDescriptor( + extensions=[SupportedExtension.INTERNATIONAL, SupportedExtension.UK_CLINICAL, + SupportedExtension.UK_CLINICAL_REFSET, SupportedExtension.UK_EDITION], + ignores={RefSetFileType.refset: [SupportedExtension.INTERNATIONAL]} + ) + UK_DRUG_EXT = BundleDescriptor( + extensions=[SupportedExtension.UK_DRUG, SupportedExtension.UK_EDITION], + ) + + +def match_partials_with_folders(exp_names: List[Tuple[str, Optional[str]]], + folder_names: List[str], + _group_nr1: int = 1, _group_nr2: int = 2) -> bool: + if len(exp_names) > len(folder_names): + return False + available_folders = [os.path.basename(f) for f in folder_names] + for exp_name, exp_name_p2 in exp_names: + found_cur_name = False + for fi, folder in enumerate(available_folders): + m = SNOMED_FOLDER_NAME_PATTERN.match(folder) + if not m: + continue + if m.group(_group_nr1) != exp_name: + continue + if exp_name_p2 and m.group(_group_nr2) != exp_name_p2: + continue + found_cur_name = True + break + if found_cur_name: + available_folders.pop(fi) + else: + return False + return True + + class Snomed: """ Pre-process SNOMED CT release files. @@ -74,26 +245,73 @@ class Snomed: uk_drug_ext (bool, optional): Specifies whether the version is a SNOMED UK drug extension. Defaults to False. au_ext (bool, optional): Specifies whether the version is a AU release. Defaults to False. 
""" + NO_VERSION_DETECTED = 'N/A' - def __init__(self, data_path, uk_ext=False, uk_drug_ext=False, au_ext: bool = False): + def __init__(self, data_path): self.data_path = data_path - self.release = data_path[-16:-8] - self.uk_ext = uk_ext - self.uk_drug_ext = uk_drug_ext - self.opcs_refset_id = "1126441000000105" - if ((self.uk_ext or self.uk_drug_ext) and + self.bundle = self._determine_bundle(self.data_path) + self.paths, self.snomed_releases, self.exts = self._check_path_and_release() + + @classmethod + def _determine_bundle(cls, data_path) -> Optional[SupportedBundles]: + if not os.path.exists(data_path) or not os.path.isdir(data_path): + return None + for bundle in SupportedBundles: + folder_names = list(os.listdir(data_path)) + exp_names = [(ext.value.exp_name_in_folder, ext.value.exp_2nd_part_in_folder) + for ext in bundle.value.extensions] + if match_partials_with_folders(exp_names, folder_names): + return bundle + return None + + def _set_extension(self, release: str, extension: SupportedExtension) -> None: + # NOTE: now using the later refset IF by default + # NOTE: the OPCS4 refset ID is only relevant for UK releases + self.opcs_refset_id = '1382401000000109' + if (extension in (SupportedExtension.UK_CLINICAL, SupportedExtension.UK_DRUG) and # using lexicographical comparison below # e.g "20240101" > "20231122" results in True # yet "20231121" > "20231122" results in False - len(self.release) == len("20231122") and self.release >= "20231122"): + len(release) == len("20231122") and release < "20231122"): # NOTE for UK extensions starting from 20231122 the # OPCS4 refset ID seems to be different - self.opcs_refset_id = '1382401000000109' - self.au_ext = au_ext - # validate - if (self.uk_ext or self.uk_drug_ext) and self.au_ext: - raise ValueError("Cannot both be a UK and and a AU version. " - f"Got UK={uk_ext}, UK_Drug={uk_drug_ext}, AU={au_ext}") + self.opcs_refset_id = "1126441000000105" + self._extension = extension + + @classmethod + def _determine_extension(cls, folder_path: str, + _group_nr1: int = 1, _group_nr2: int = 2) -> SupportedExtension: + folder_basename = os.path.basename(folder_path) + m = SNOMED_FOLDER_NAME_PATTERN.match(folder_basename) + if not m: + raise UnkownSnomedReleaseException( + f"Unable to determine extension for path {repr(folder_path)}. " + f"Checking against pattern {SNOMED_FOLDER_NAME_PATTERN}") + ext_str = m.group(_group_nr1) + ext_str2 = m.group(_group_nr2) + for extension in SupportedExtension: + if extension.value.exp_name_in_folder != ext_str: + continue + if (extension.value.exp_2nd_part_in_folder and + extension.value.exp_2nd_part_in_folder != ext_str2): + continue + return extension + ext_names_folders = ",".join([f"{ext.name} ({ext.value.exp_name_in_folder})" + for ext in SupportedExtension]) + raise UnkownSnomedReleaseException( + f"Cannot Find the extension for {folder_path}. " + f"Tried the following extensions: {ext_names_folders}") + + @classmethod + def _determine_release(cls, folder_path: str, strict: bool = True, + _group_nr: int = 3, _keep_chars: int = 8) -> str: + folder_basename = os.path.basename(folder_path) + match = SNOMED_FOLDER_NAME_PATTERN.match(folder_basename) + if match is None and strict: + raise UnkownSnomedReleaseException(f"No version found in '{folder_path}'") + elif match is None: + return cls.NO_VERSION_DETECTED + return match.group(_group_nr)[:_keep_chars] def to_concept_df(self): """ @@ -106,37 +324,17 @@ def to_concept_df(self): Returns: pandas.DataFrame: SNOMED CT concept DataFrame. 
""" - paths, snomed_releases = self._check_path_and_release() df2merge = [] - for i, snomed_release in enumerate(snomed_releases): - contents_path = os.path.join(paths[i], "Snapshot", "Terminology") - concept_snapshot = "sct2_Concept_Snapshot" - description_snapshot = "sct2_Description_Snapshot-en" - if self.au_ext: - description_snapshot += "-AU" - if self.uk_ext: - if "SnomedCT_UKClinicalRF2_PRODUCTION" in paths[i]: - concept_snapshot = "sct2_Concept_UKCLSnapshot" - description_snapshot = "sct2_Description_UKCLSnapshot-en" - elif "SnomedCT_UKEditionRF2_PRODUCTION" in paths[i]: - concept_snapshot = "sct2_Concept_UKEDSnapshot" - description_snapshot = "sct2_Description_UKEDSnapshot-en" - elif "SnomedCT_UKClinicalRefsetsRF2_PRODUCTION" in paths[i]: - continue - else: - pass - if self.uk_drug_ext: - if "SnomedCT_UKDrugRF2_PRODUCTION" in paths[i]: - concept_snapshot = "sct2_Concept_UKDGSnapshot" - description_snapshot = "sct2_Description_UKDGSnapshot-en" - elif "SnomedCT_UKEditionRF2_PRODUCTION" in paths[i]: - concept_snapshot = "sct2_Concept_UKEDSnapshot" - description_snapshot = "sct2_Description_UKEDSnapshot-en" - elif "SnomedCT_UKClinicalRefsetsRF2_PRODUCTION" in paths[i]: - continue - else: - pass + for i, snomed_release in enumerate(self.snomed_releases): + self._set_extension(snomed_release, self.exts[i]) + contents_path = os.path.join(self.paths[i], PER_FILE_TYPE_PATHS[RefSetFileType.concept]) + concept_snapshot = self._extension.value.exp_files.get_concept() + description_snapshot = self._extension.value.exp_files.get_description() + if concept_snapshot is None or _IGNORE_TAG in concept_snapshot or ( + self.bundle and self.bundle.value.has_invalid( + self._extension, [RefSetFileType.concept, RefSetFileType.description])): + continue for f in os.listdir(contents_path): m = re.search(f'{concept_snapshot}'+r'_(.*)_\d*.txt', f) @@ -202,37 +400,16 @@ def list_all_relationships(self): Returns: list: List of all SNOMED CT relationships. 
""" - paths, snomed_releases = self._check_path_and_release() all_rela = [] - for i, snomed_release in enumerate(snomed_releases): - contents_path = os.path.join(paths[i], "Snapshot", "Terminology") - concept_snapshot = "sct2_Concept_Snapshot" - relationship_snapshot = "sct2_Relationship_Snapshot" - if self.uk_ext: - if "SnomedCT_InternationalRF2_PRODUCTION" in paths[i]: - concept_snapshot = "sct2_Concept_Snapshot" - relationship_snapshot = "sct2_Relationship_Snapshot" - elif "SnomedCT_UKClinicalRF2_PRODUCTION" in paths[i]: - concept_snapshot = "sct2_Concept_UKCLSnapshot" - relationship_snapshot = "sct2_Relationship_UKCLSnapshot" - elif "SnomedCT_UKEditionRF2_PRODUCTION" in paths[i]: - concept_snapshot = "sct2_Concept_UKEDSnapshot" - relationship_snapshot = "sct2_Relationship_UKEDSnapshot" - elif "SnomedCT_UKClinicalRefsetsRF2_PRODUCTION" in paths[i]: - continue - else: - pass - if self.uk_drug_ext: - if "SnomedCT_UKDrugRF2_PRODUCTION" in paths[i]: - concept_snapshot = "sct2_Concept_UKDGSnapshot" - relationship_snapshot = "sct2_Relationship_UKDGSnapshot" - elif "SnomedCT_UKEditionRF2_PRODUCTION" in paths[i]: - concept_snapshot = "sct2_Concept_UKEDSnapshot" - relationship_snapshot = "sct2_Relationship_UKEDSnapshot" - elif "SnomedCT_UKClinicalRefsetsRF2_PRODUCTION" in paths[i]: - continue - else: - pass + for i, snomed_release in enumerate(self.snomed_releases): + self._set_extension(snomed_release, self.exts[i]) + contents_path = os.path.join(self.paths[i], PER_FILE_TYPE_PATHS[RefSetFileType.concept]) + concept_snapshot = self._extension.value.exp_files.get_concept() + relationship_snapshot = self._extension.value.exp_files.get_relationship() + if concept_snapshot is None or _IGNORE_TAG in concept_snapshot or ( + self.bundle and self.bundle.value.has_invalid( + self._extension, [RefSetFileType.concept, RefSetFileType.description])): + continue for f in os.listdir(contents_path): m = re.search(f'{concept_snapshot}'+r'_(.*)_\d*.txt', f) @@ -259,37 +436,16 @@ def relationship2json(self, relationshipcode, output_jsonfile): Returns: file: JSON file of relationship mapping. 
""" - paths, snomed_releases = self._check_path_and_release() output_dict = {} - for i, snomed_release in enumerate(snomed_releases): - contents_path = os.path.join(paths[i], "Snapshot", "Terminology") - concept_snapshot = "sct2_Concept_Snapshot" - relationship_snapshot = "sct2_Relationship_Snapshot" - if self.uk_ext: - if "SnomedCT_InternationalRF2_PRODUCTION" in paths[i]: - concept_snapshot = "sct2_Concept_Snapshot" - relationship_snapshot = "sct2_Relationship_Snapshot" - elif "SnomedCT_UKClinicalRF2_PRODUCTION" in paths[i]: - concept_snapshot = "sct2_Concept_UKCLSnapshot" - relationship_snapshot = "sct2_Relationship_UKCLSnapshot" - elif "SnomedCT_UKEditionRF2_PRODUCTION" in paths[i]: - concept_snapshot = "sct2_Concept_UKEDSnapshot" - relationship_snapshot = "sct2_Relationship_UKEDSnapshot" - elif "SnomedCT_UKClinicalRefsetsRF2_PRODUCTION" in paths[i]: - continue - else: - pass - if self.uk_drug_ext: - if "SnomedCT_UKDrugRF2_PRODUCTION" in paths[i]: - concept_snapshot = "sct2_Concept_UKDGSnapshot" - relationship_snapshot = "sct2_Relationship_UKDGSnapshot" - elif "SnomedCT_UKEditionRF2_PRODUCTION" in paths[i]: - concept_snapshot = "sct2_Concept_UKEDSnapshot" - relationship_snapshot = "sct2_Relationship_UKEDSnapshot" - elif "SnomedCT_UKClinicalRefsetsRF2_PRODUCTION" in paths[i]: - continue - else: - pass + for i, snomed_release in enumerate(self.snomed_releases): + self._set_extension(snomed_release, self.exts[i]) + contents_path = os.path.join(self.paths[i], PER_FILE_TYPE_PATHS[RefSetFileType.concept]) + concept_snapshot = self._extension.value.exp_files.get_concept() + relationship_snapshot = self._extension.value.exp_files.get_relationship() + if concept_snapshot is None or _IGNORE_TAG in concept_snapshot or ( + self.bundle and self.bundle.value.has_invalid( + self._extension, [RefSetFileType.concept, RefSetFileType.description])): + continue for f in os.listdir(contents_path): m = re.search(f'{concept_snapshot}'+r'_(.*)_\d*.txt', f) @@ -322,10 +478,7 @@ def map_snomed2icd10(self): dict: A dictionary containing the SNOMED CT to ICD-10 mappings including metadata. """ snomed2icd10df = self._map_snomed2refset() - if self.uk_ext is True: - return self._refset_df2dict(snomed2icd10df[0]) - else: - return self._refset_df2dict(snomed2icd10df) + return self._refset_df2dict(snomed2icd10df[0]) def map_snomed2opcs4(self) -> dict: """ @@ -340,7 +493,8 @@ def map_snomed2opcs4(self) -> dict: Returns: dict: A dictionary containing the SNOMED CT to OPCS-4 mappings including metadata. 
""" - if self.uk_ext is not True: + if all(ext not in (SupportedExtension.UK_CLINICAL, SupportedExtension.UK_DRUG) + for ext in self.exts): raise AttributeError( "OPCS-4 mapping does not exist in this edition") snomed2opcs4df = self._map_snomed2refset()[1] @@ -361,17 +515,21 @@ def _check_path_and_release(self): """ snomed_releases = [] paths = [] + exts = [] if "Snapshot" in os.listdir(self.data_path): paths.append(self.data_path) - snomed_releases.append(self.release) + snomed_releases.append(self._determine_release(self.data_path, strict=True)) + exts.append(self._determine_extension(self.data_path)) else: for folder in os.listdir(self.data_path): if "SnomedCT" in folder: paths.append(os.path.join(self.data_path, folder)) - snomed_releases.append(folder[-16:-8]) + rel = self._determine_release(folder, strict=True) + snomed_releases.append(rel) + exts.append(self._determine_extension(paths[-1])) if len(paths) == 0: raise FileNotFoundError('Incorrect path to SNOMED CT directory') - return paths, snomed_releases + return paths, snomed_releases, exts def _refset_df2dict(self, refset_df: pd.DataFrame) -> dict: """ @@ -403,31 +561,15 @@ def _map_snomed2refset(self): OR tuple: Tuple of dataframes containing SNOMED CT to refset mappings and metadata (ICD-10, OPCS4), if uk_ext is True. """ - paths, snomed_releases = self._check_path_and_release() dfs2merge = [] - for i, snomed_release in enumerate(snomed_releases): - refset_terminology = f'{paths[i]}/Snapshot/Refset/Map' - icd10_ref_set = 'der2_iisssccRefset_ExtendedMapSnapshot' - if self.uk_ext: - if "SnomedCT_InternationalRF2_PRODUCTION" in paths[i]: - continue - elif "SnomedCT_UKClinicalRF2_PRODUCTION" in paths[i]: - icd10_ref_set = "der2_iisssciRefset_ExtendedMapUKCLSnapshot" - elif "SnomedCT_UKEditionRF2_PRODUCTION" in paths[i]: - icd10_ref_set = "der2_iisssciRefset_ExtendedMapUKEDSnapshot" - elif "SnomedCT_UKClinicalRefsetsRF2_PRODUCTION" in paths[i]: - continue - else: - pass - if self.uk_drug_ext: - if "SnomedCT_UKDrugRF2_PRODUCTION" in paths[i]: - icd10_ref_set = "der2_iisssciRefset_ExtendedMapUKDGSnapshot" - elif "SnomedCT_UKEditionRF2_PRODUCTION" in paths[i]: - icd10_ref_set = "der2_iisssciRefset_ExtendedMapUKEDSnapshot" - elif "SnomedCT_UKClinicalRefsetsRF2_PRODUCTION" in paths[i]: - continue - else: - pass + for i, snomed_release in enumerate(self.snomed_releases): + self._set_extension(snomed_release, self.exts[i]) + refset_terminology = os.path.join(self.paths[i], PER_FILE_TYPE_PATHS[RefSetFileType.refset]) + icd10_ref_set = self._extension.value.exp_files.get_refset() + if icd10_ref_set is None or _IGNORE_TAG in icd10_ref_set or ( + self.bundle and self.bundle.value.has_invalid( + self._extension, [RefSetFileType.concept, RefSetFileType.description])): + continue for f in os.listdir(refset_terminology): m = re.search(f'{icd10_ref_set}'+r'_(.*)_\d*.txt', f) if m: @@ -440,10 +582,17 @@ def _map_snomed2refset(self): dfs2merge.append(icd_mappings) mapping_df = pd.concat(dfs2merge) del dfs2merge - if self.uk_ext or self.uk_drug_ext: + if any(ext in (SupportedExtension.UK_CLINICAL, SupportedExtension.UK_DRUG) + for ext in self.exts): opcs_df = mapping_df[mapping_df['refsetId'] == self.opcs_refset_id] icd10_df = mapping_df[mapping_df['refsetId'] == '999002271000000101'] return icd10_df, opcs_df else: - return mapping_df + return mapping_df, None + + +class UnkownSnomedReleaseException(ValueError): + + def __init__(self, *args) -> None: + super().__init__(*args) diff --git a/medcat/utils/regression/checking.py 
b/medcat/utils/regression/checking.py index d3c425583..2c2d52ce9 100644 --- a/medcat/utils/regression/checking.py +++ b/medcat/utils/regression/checking.py @@ -12,7 +12,9 @@ from medcat.utils.regression.targeting import TranslationLayer, OptionSet from medcat.utils.regression.targeting import FinalTarget, TargetedPhraseChanger from medcat.utils.regression.utils import partial_substitute, MedCATTrainerExportConverter +from medcat.utils.regression.utils import pick_random_edits from medcat.utils.regression.results import MultiDescriptor, ResultDescriptor, Finding +from medcat.utils.normalizers import get_all_edits_n logger = logging.getLogger(__name__) @@ -69,7 +71,9 @@ def check_specific_for_phrase(self, cat: CAT, target: FinalTarget, def estimate_num_of_diff_subcases(self) -> int: return len(self.phrases) * self.options.estimate_num_of_subcases() - def get_distinct_cases(self, translation: TranslationLayer) -> Iterator[Iterator[FinalTarget]]: + def get_distinct_cases(self, translation: TranslationLayer, + edit_distance: Tuple[int, int, int], + use_diacritics: bool) -> Iterator[Iterator[FinalTarget]]: """Gets the various distinct sub-case iterators. The sub-cases are those that can be determine without the translation layer. @@ -77,6 +81,8 @@ def get_distinct_cases(self, translation: TranslationLayer) -> Iterator[Iterator Args: translation (TranslationLayer): The translation layer. + edit_distance (Tuple[int, int, int]): The edit distance(s) to try. + use_diacritics (bool): Whether to use diacritics for edit distance. Yields: Iterator[Iterator[FinalTarget]]: The iterator of iterators of different sub cases. @@ -84,24 +90,43 @@ def get_distinct_cases(self, translation: TranslationLayer) -> Iterator[Iterator # for each phrase and for each placeholder based option for changer in self.options.get_preprocessors_and_targets(translation): for phrase in self.phrases: - yield self._get_subcases(phrase, changer, translation) + yield self._get_subcases(phrase, changer, translation, edit_distance, use_diacritics) def _get_subcases(self, phrase: str, changer: TargetedPhraseChanger, - translation: TranslationLayer) -> Iterator[FinalTarget]: + translation: TranslationLayer, + edit_distance: Tuple[int, int, int], + use_diacritics: bool, + ) -> Iterator[FinalTarget]: cui, placeholder = changer.cui, changer.placeholder changed_phrase = changer.changer(phrase) - for name in translation.get_names_of(cui, changer.onlyprefnames): - num_of_phs = changed_phrase.count(placeholder) - if num_of_phs == 1: - yield FinalTarget(placeholder=placeholder, - cui=cui, name=name, - final_phrase=changed_phrase) - continue - for cntr in range(num_of_phs): - final_phrase = partial_substitute(changed_phrase, placeholder, name, cntr) - yield FinalTarget(placeholder=placeholder, - cui=cui, name=name, - final_phrase=final_phrase) + edit_dist, edit_rn_seed, edit_pick = edit_distance + for raw_name in translation.get_names_of(cui, changer.onlyprefnames): + name_variant = 0 + if edit_dist:# TODO: use config.ner.min_name_len or something + name_gen = get_all_edits_n( + raw_name, use_diacritics, edit_dist, return_ordered=True) + all_names = list(pick_random_edits(name_gen, edit_pick, len(raw_name), + edit_dist, edit_rn_seed)) + else: + all_names = [raw_name] + for name in all_names: + if edit_dist: + logger.debug("Changed name from '%s' to '%s' (variant %d, edit distance %s, " + "seed %d, picking %d)", + raw_name, name, name_variant, edit_dist, + edit_rn_seed, edit_pick) + name_variant += 1 + num_of_phs = changed_phrase.count(placeholder) 
+ if num_of_phs == 1: + yield FinalTarget(placeholder=placeholder, + cui=cui, name=name, + final_phrase=changed_phrase) + continue + for cntr in range(num_of_phs): + final_phrase = partial_substitute(changed_phrase, placeholder, name, cntr) + yield FinalTarget(placeholder=placeholder, + cui=cui, name=name, + final_phrase=final_phrase) def to_dict(self) -> dict: """Converts the RegressionCase to a dict for serialisation. @@ -293,7 +318,9 @@ def __init__(self, cases: List[RegressionCase], metadata: MetaData, name: str) - for case in self.cases: self.report.parts.append(case.report) - def get_all_distinct_cases(self, translation: TranslationLayer + def get_all_distinct_cases(self, translation: TranslationLayer, + edit_distance: Tuple[int, int, int], + use_diacritics: bool ) -> Iterator[Tuple[RegressionCase, Iterator[FinalTarget]]]: """Gets all the distinct cases for this regression suite. @@ -302,13 +329,17 @@ def get_all_distinct_cases(self, translation: TranslationLayer Args: translation (TranslationLayer): The translation layer. + edit_distance (Tuple[int, int, int]): The edit distance(s) to try. + Defaults to (0, 0, 0). + use_diacritics (bool): Whether to use diacritics for edit distance. Yields: Iterator[Tuple[RegressionCase, Iterator[FinalTarget]]]: The generator of the regression case along with its corresponding sub-cases. """ for regr_case in self.cases: - for subcase in regr_case.get_distinct_cases(translation): + for subcase in regr_case.get_distinct_cases(translation, edit_distance, + use_diacritics): yield regr_case, subcase def estimate_total_distinct_cases(self) -> int: @@ -316,6 +347,8 @@ def estimate_total_distinct_cases(self) -> int: def iter_subcases(self, translation: TranslationLayer, show_progress: bool = True, + edit_distance: Tuple[int, int, int] = (0, 0, 0), + use_diacritics: bool = False, ) -> Iterator[Tuple[RegressionCase, FinalTarget]]: """Iterate over all the sub-cases. @@ -325,28 +358,40 @@ def iter_subcases(self, translation: TranslationLayer, Args: translation (TranslationLayer): The translation layer. show_progress (bool): Whether to show progress. Defaults to True. + edit_distance (Tuple[int, int, int]): The edit distance(s) to try. + Defaults to (0, 0, 0). + use_diacritics (bool): Whether to use diacritics for edit distance. Yields: Iterator[Tuple[RegressionCase, FinalTarget]]: The generator of the regression case along with each of the final target sub-cases. """ total = self.estimate_total_distinct_cases() - for (regr_case, subcase) in tqdm.tqdm(self.get_all_distinct_cases(translation), + for (regr_case, subcase) in tqdm.tqdm(self.get_all_distinct_cases(translation, + edit_distance, + use_diacritics), total=total, disable=not show_progress): for target in subcase: yield regr_case, target - def check_model(self, cat: CAT, translation: TranslationLayer) -> MultiDescriptor: + def check_model(self, cat: CAT, translation: TranslationLayer, + edit_distance: Tuple[int, int, int] = (0, 0, 0), + use_diacritics: bool = False, + ) -> MultiDescriptor: """Checks model and generates a report Args: cat (CAT): The model to check against translation (TranslationLayer): The translation layer + edit_distance (Tuple[int, int, int]): The edit distance of the names. + Defaults to (0, 0, 0). + use_diacritics (bool): Whether to use diacritics for edit distance. 
Returns: MultiDescriptor: A report description """ - for regr_case, target in self.iter_subcases(translation, True): + for regr_case, target in self.iter_subcases(translation, True, + edit_distance, use_diacritics): # NOTE: the finding is reported in the per-case report regr_case.check_specific_for_phrase(cat, target, translation) return self.report diff --git a/medcat/utils/regression/regression_checker.py b/medcat/utils/regression/regression_checker.py index 5cb743734..4906e9db8 100644 --- a/medcat/utils/regression/regression_checker.py +++ b/medcat/utils/regression/regression_checker.py @@ -3,7 +3,7 @@ from pathlib import Path import logging -from typing import Optional +from typing import Optional, Tuple from medcat.cat import CAT from medcat.utils.regression.checking import RegressionSuite, TranslationLayer @@ -47,7 +47,8 @@ def main(model_pack_dir: Path, test_suite_file: Path, mct_export_yaml_path: Optional[str] = None, only_mct_export_conversion: bool = False, only_describe: bool = False, - require_fully_correct: bool = False) -> None: + require_fully_correct: bool = False, + edit_distance: Tuple[int, int, int] = (0, 0, 0)) -> None: """Check test suite against the specified model pack. Args: @@ -72,6 +73,11 @@ def main(model_pack_dir: Path, test_suite_file: Path, require_fully_correct (bool): Whether all cases are required to be correct. If set to True, an exit-status of 1 is returned unless all (sub)cases are correct. Defaults to False. + edit_distance (Tuple[int, int, int]): The edit distance, the random seed, and the number + of edited names to pick for each of the names. If set to non-0, the specified number + of splits, deletes, transposes, replaces, or inserts are done to each name. This + can be useful for looking at the capability of identifying typos in text. However, + this can make the process a lot slower as a result. Defaults to (0, 0, 0). Raises: ValueError: If unable to overwrite file or folder does not exist. @@ -101,7 +107,10 @@ def main(model_pack_dir: Path, test_suite_file: Path, logger.info('Loading model pack from file: %s', model_pack_dir) cat: CAT = CAT.load_model_pack(str(model_pack_dir)) logger.info('Checking the current status') - res = rc.check_model(cat, TranslationLayer.from_CDB(cat.cdb)) + res = rc.check_model(cat, TranslationLayer.from_CDB(cat.cdb), + edit_distance=edit_distance, + use_diacritics=cat.config.general.diacritics) + cat.config.general strictness = Strictness[strictness_str] if examples_strictness_str in ("None", "N/A"): examples_strictness = None @@ -123,6 +132,16 @@ def main(model_pack_dir: Path, test_suite_file: Path, exit(1) +def tuple3_parser(arg: str) -> Tuple[int, int, int]: + parts = arg.strip("()").split(',') + if len(parts) != 3: + raise argparse.ArgumentTypeError("Tuple must be in the form (x, y, z)") + try: + return (int(parts[0]), int(parts[1]), int(parts[2])) + except ValueError: + raise argparse.ArgumentTypeError("Tuple must be in the form (x, y, z)") + + if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('modelpack', help='The model pack against which to check', @@ -166,9 +185,22 @@ def main(model_pack_dir: Path, test_suite_file: Path, parser.add_argument('--only-describe', help='Only describe the various findings and exit.', action='store_true') parser.add_argument('--require-fully-correct', help='Require the regression test to be fully correct. ' - 'If set, a non-zero exit status is returned unless all cases are successful (100%). 
' + 'If set, a non-zero exit status is returned unless all cases are successful (100%%). ' 'This can be useful for (e.g.) CI workflow integration.', action='store_true') + parser.add_argument('--edit-distance', help='Set the edit distance of each of the names. ' + 'If set, each name tested will have the specified number of characters changed. ' + 'This can be useful to determine the versatility of the model in terms of ' + 'recognising typos. Defaults to 0 (i.e. no change). You need to provide 3 numbers ' + 'in the format `(N, R, P)` where `N` is the edit distance, `R` is the random seed, ' + 'and `P` is the number of choices to make.', + # ' NOTE: Edit distances greater ' + # 'than 1 will add an exponentially higher and higher number of sub-cases and thus ' + # 'time for the regression suite to be run. Even at edit distance 2 you can have ' + # '500 000 different variants for a 15 character long name and the longer the name, ' + # 'the more variations you get. For instance, a 76 character long name could have ' + # 'upwards of 15 million variants.', + type=tuple3_parser, default=(0, 0, 0)) args = parser.parse_args() if not args.silent: logger.addHandler(logging.StreamHandler()) @@ -183,4 +215,4 @@ def main(model_pack_dir: Path, test_suite_file: Path, strictness_str=args.strictness, max_phrase_length=args.max_phrase_length, use_mct_export=args.from_mct_export, mct_export_yaml_path=args.mct_export_yaml, only_mct_export_conversion=args.only_conversion, only_describe=args.only_describe, - require_fully_correct=args.require_fully_correct) + require_fully_correct=args.require_fully_correct, edit_distance=args.edit_distance) diff --git a/medcat/utils/regression/results.py b/medcat/utils/regression/results.py index 421ec217a..2667a970a 100644 --- a/medcat/utils/regression/results.py +++ b/medcat/utils/regression/results.py @@ -452,8 +452,13 @@ def iter_examples(self, strictness_threshold: Strictness Yields: Iterable[Tuple[FinalTarget, Tuple[Finding, Optional[str]]]]: The placeholder, phrase, finding, CUI, and name. """ - for srd in self.per_phrase_results.values(): - for target, finding in srd.examples: + phrases = sorted(self.per_phrase_results.keys()) + for phrase in phrases: + srd = self.per_phrase_results[phrase] + # sort by finding 1st, found CUI 2nd, and used name 3rd + sorted_examples = sorted( + srd.examples, key=lambda tf: (tf[1][0].name, str(tf[1][1]), tf[0].name)) + for target, finding in sorted_examples: if finding[0] not in STRICTNESS_MATRIX[strictness_threshold]: yield target, finding @@ -490,7 +495,8 @@ def dict(self, **kwargs) -> dict: # NOTE: need to propagate here manually so the strictness keyword # makes sense and doesn't cause issues due to being an unexpected keyword per_phrase_results = { - phrase: res.dict(**kwargs) for phrase, res in self.per_phrase_results.items() + phrase: res.dict(**kwargs) for phrase, res in + sorted(self.per_phrase_results.items(), key=lambda it: it[0]) } d['per_phrase_results'] = per_phrase_results return d @@ -654,8 +660,8 @@ def get_report(self, phrases_separately: bool, if hide_empty: empty_text = f' A total of {nr_of_empty} cases did not match any CUIs and/or names.' ret_vals = [f"""A total of {len(self.parts)} parts were kept track of within the group "{self.name}". 
-And a total of {total_total} (sub)cases were checked.{empty_text}"""] - allowed_fingings_str = [f.name for f in allowed_findings] + And a total of {total_total} (sub)cases were checked.{empty_text}"""] + allowed_fingings_str = sorted([f.name for f in allowed_findings]) ret_vals.extend([ f"At the strictness level of {strictness} (allowing {allowed_fingings_str}):", f"The number of total successful (sub) cases: {total_s} " diff --git a/medcat/utils/regression/targeting.py b/medcat/utils/regression/targeting.py index 8acd12f3a..cc8e494f7 100644 --- a/medcat/utils/regression/targeting.py +++ b/medcat/utils/regression/targeting.py @@ -68,7 +68,11 @@ def get_names_of(self, cui: str, only_prefnames: bool) -> List[str]: if only_prefnames: return [self.get_preferred_name(cui).replace(self.separator, self.whitespace)] return [name.replace(self.separator, self.whitespace) - for name in self.cui2names.get(cui, [])] + # NOTE: sorting the order here in case we're using + # edirts in which case the order of the names + # needs to be the same, otherwise different + # edits will be used across runs + for name in sorted(self.cui2names.get(cui, []))] def get_preferred_name(self, cui: str) -> str: """Get the preferred name of a concept. diff --git a/medcat/utils/regression/utils.py b/medcat/utils/regression/utils.py index 3d630bec3..7b6a7607b 100644 --- a/medcat/utils/regression/utils.py +++ b/medcat/utils/regression/utils.py @@ -1,8 +1,11 @@ -from typing import Iterator, Tuple, List, Dict, Any, Type +from typing import Iterator, Tuple, List, Dict, Any, Type, Callable, Set import ast import inspect from enum import Enum +from functools import lru_cache +import random +import logging from medcat.stats.mctexport import MedCATTrainerExport, MedCATTrainerExportDocument @@ -223,3 +226,62 @@ def add_doc_strings_to_enum(cls: Type[Enum]) -> None: docstrings = docstrings[1:] for ev, ds in zip(cls, docstrings): ev.__doc__ = ds + + +@lru_cache(maxsize=10) +def get_rng(seed: int) -> random.Random: + return random.Random(seed) + + +# NOTE: these are 'relatively accurate' estimates +# that I obtained by running it on 15 different +# concepts with names varying from length +# of 5 to length of 74, a total of 316 names +ESTIMATION_MATRIX: Dict[int, Callable[[int], int]] = { + 1: lambda orig_len: int(52.23 * orig_len + 24.26), + 2: lambda orig_len: 2724 * orig_len**2 + 3917 * orig_len + 1098 +} + + +def estimate_num_variants(orig_len: int, edit_distance: int) -> int: + if edit_distance in ESTIMATION_MATRIX: + return ESTIMATION_MATRIX[edit_distance](orig_len) + logging.warning("Estimations for then umber of varinats for edit " + "distance greater than 2 (%d used) can be extremely " + "inaccurate.") + # NOTE: This is a low ball estimate - the real number could be a lot bigger + powers = list(range(0, edit_distance+1))[::-1] + estimate_coefs = [(2 * 26) ** ed for ed in powers] + estimated = 0 + for coef, power in zip(estimate_coefs, powers): + estimated += coef * orig_len ** power + return estimated + + +FAIL_AFTER_MULT = 10 + + +def pick_random_edits(edit_gen: Iterator[str], num_to_pick: int, + orig_len: int, edit_distance: int, rng_seed: int) -> Iterator[str]: + num_vars = estimate_num_variants(orig_len, edit_distance) + if num_to_pick > num_vars: + raise ValueError(f"Unable to ick {num_to_pick} out of {num_vars} " + f"(estimated from edit distance {edit_distance} " + f"and word length {orig_len})") + rng = get_rng(rng_seed) + pick_avoids = num_to_pick > num_vars // 2 + _num_to_pick = num_to_pick if not pick_avoids else 
num_vars - num_to_pick + pick_set: Set[int] = set() + while len(pick_set) < _num_to_pick: + pick_set.add(rng.randint(0, num_vars)) + if pick_avoids: + # NUMBERS NOT IN + picks = sorted(set(range(num_vars)) - pick_set) + else: + picks = sorted(list(pick_set)) + for enr, edit in enumerate(edit_gen): + if enr == picks[0]: + picks.pop(0) + yield edit + if not picks: + break diff --git a/medcat/utils/versioning.py b/medcat/utils/versioning.py deleted file mode 100644 index 005213421..000000000 --- a/medcat/utils/versioning.py +++ /dev/null @@ -1,438 +0,0 @@ -from typing import Tuple, List -import re -import os -import shutil -import argparse -import logging -from functools import partial - -import dill -import json - -from medcat.cat import CAT -from medcat.utils.decorators import deprecated -from medcat.utils.config_utils import default_weighted_average - -logger = logging.getLogger(__name__) - -SemanticVersion = Tuple[int, int, int] - - -# Regex as per: -# https://semver.org/#is-there-a-suggested-regular-expression-regex-to-check-a-semver-string -SEMANTIC_VERSION_REGEX = (r"^(0|[1-9]\d*)" # major - r"\.(0|[1-9]\d*)" # .minor - # CHANGE FROM NORM - allowing dev before patch version number - # but NOT capturing the group - r"\.(?:dev)?" - r"(0|[1-9]\d*)" # .patch - # and then some trailing stuff - r"(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?" - r"(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$") -SEMANTIC_VERSION_PATTERN = re.compile(SEMANTIC_VERSION_REGEX) - - -CDB_FILE_NAME = "cdb.dat" - - -def get_semantic_version(version: str) -> SemanticVersion: - """Get the semantiv version from the string. - - Args: - version (str): The version string. - - Raises: - ValueError: If the version string does not match the semantic versioning format. - - Returns: - SemanticVersion: The major, minor and patch version - """ - match = SEMANTIC_VERSION_PATTERN.match(version) - if not match: - raise ValueError(f"Unknown version string: {version}") - return int(match.group(1)), int(match.group(2)), int(match.group(3)) - - -def get_version_from_modelcard(d: dict) -> SemanticVersion: - """Gets the the major.minor.patch version from a model card. - - The version needs to be specified at: - model_card["MedCAT Version"] - The version is expected to be semantic (major.minor.patch). - - Args: - d (dict): The model card in dict format. - - Returns: - SemanticVersion: The major, minor and patch version - """ - version_str: str = d["MedCAT Version"] - return get_semantic_version(version_str) - - -def get_semantic_version_from_model(cat: CAT) -> SemanticVersion: - """Get the semantic version of a CAT model. - - This uses the `get_version_from_modelcard` method on the model's - model card. - - So it is equivalen to `get_version_from_modelcard(cat.get_model_card(as_dict=True))`. - - Args: - cat (CAT): The CAT model. - - Returns: - SemanticVersion: The major, minor and patch version - """ - return get_version_from_modelcard(cat.get_model_card(as_dict=True)) - - -def get_version_from_cdb_dump(cdb_path: str) -> SemanticVersion: - """Get the version from a CDB dump (cdb.dat). 
- - The version information is expected in the following location: - cdb["config"]["version"]["medcat_version"] - - Args: - cdb_path (str): The path to cdb.dat - - Returns: - SemanticVersion: The major, minor and patch version - """ - with open(cdb_path, 'rb') as f: - d = dill.load(f) - config: dict = d["config"] - version = config["version"]["medcat_version"] - return get_semantic_version(version) - - -def get_version_from_modelpack_zip(zip_path: str, cdb_file_name=CDB_FILE_NAME) -> SemanticVersion: - """Get the semantic version from a MedCAT model pack zip file. - - This involves simply reading the config on file and reading the version information from there. - - The zip file is extracted if it has not yet been extracted. - - Args: - zip_path (str): The zip file path for the model pack. - cdb_file_name (str, optional): The CDB file name to use. Defaults to "cdb.dat". - - Returns: - SemanticVersion: The major, minor and patch version - """ - model_pack_path = CAT.attempt_unpack(zip_path) - return get_version_from_cdb_dump(os.path.join(model_pack_path, cdb_file_name)) - - -UPDATE_VERSION = (1, 3, 0) - - -class ConfigUpgrader: - """Config updater. - - Attempts to upgrade pre 1.3.0 medcat configs to the newer format. - - Args: - zip_path (str): The model pack zip path. - cdb_file_name (str, optional): The CDB file name. Defaults to "cdb.dat". - """ - - def __init__(self, zip_path: str, cdb_file_name: str = CDB_FILE_NAME) -> None: - self.model_pack_path = CAT.attempt_unpack(zip_path) - self.cdb_path = os.path.join(self.model_pack_path, cdb_file_name) - self.current_version = get_version_from_cdb_dump(self.cdb_path) - logger.debug("Loaded model from %s at version %s", - self.model_pack_path, self.current_version) - - def needs_upgrade(self) -> bool: - """Check if the specified modelpack needs an upgrade. - - It needs an upgrade if its version is less than 1.3.0. - - Returns: - bool: Whether or not an upgrade is needed. - """ - return self.current_version < UPDATE_VERSION - - def _get_relevant_files(self, ignore_hidden: bool = True) -> List[str]: - """Get the list of relevant files with full path names. - - By default this will ignore hidden files (those that start with '.'). - - Args: - ignore_hidden (bool): Whether to ignore hidden files. Defaults to True. - - Returns: - List[str]: The list of relevant file names to copy. - """ - return [os.path.join(self.model_pack_path, fn) # ignores hidden files - for fn in os.listdir(self.model_pack_path) if (ignore_hidden and not fn.startswith("."))] - - def _check_existence(self, files_to_copy: List[str], new_path: str, overwrite: bool): - if overwrite: - return # ignore all - if not os.path.exists(new_path): - os.makedirs(new_path) - return # all good, new folder - # check file existence in new (existing) path - for file_to_copy in files_to_copy: - new_file_name = os.path.join( - new_path, os.path.basename(file_to_copy)) - if os.path.exists(new_file_name): - raise ValueError(f"File already exists: {new_file_name}. 
" - "Pass overwrite=True to overwrite") - - def _copy_files(self, files_to_copy: List[str], new_path: str) -> None: - for file_to_copy in files_to_copy: - new_file_name = os.path.join( - new_path, os.path.basename(file_to_copy)) - if os.path.isdir(file_to_copy): - # if exists is OK since it should have been checked before - # if it was not to be overwritten - logger.debug("Copying folder %s to %s", - file_to_copy, new_file_name) - shutil.copytree(file_to_copy, new_file_name, - dirs_exist_ok=True) - else: - logger.debug("Copying file %s to %s", - file_to_copy, new_file_name) - shutil.copy(file_to_copy, new_file_name) - - def upgrade(self, new_path: str, overwrite: bool = False) -> None: - """Upgrade the model. - - The upgrade copies all the files from the original folder - to the new folder. - - After copying, it changes the config into the format - required by MedCAT after version 1.3.0. - - Args: - new_path (str): The path for the new model pack folder. - overwrite (bool): Whether to overwrite new path. Defaults to False. - - Raises: - ValueError: If one of the target files exists and cannot be overwritten. - IncorrectModel: If model pack does not need an upgrade - """ - if not self.needs_upgrade(): - raise IncorrectModel(f"Model pack does not need upgrade: {self.model_pack_path} " - f"since it's at version: {self.current_version}") - logger.info("Starting to upgrade %s at (version %s)", - self.model_pack_path, self.current_version) - files_to_copy = self._get_relevant_files() - try: - self._check_existence(files_to_copy, new_path, overwrite) - except ValueError as e: - raise e - logger.debug("Copying files from %s", self.model_pack_path) - self._copy_files(files_to_copy, new_path) - logger.info("Going to try and fix CDB") - self._fix_cdb(new_path) - self._make_archive(new_path) - - def _fix_cdb(self, new_path: str) -> None: - new_cdb_path = os.path.join(new_path, os.path.basename(self.cdb_path)) - with open(new_cdb_path, 'rb') as f: - data = dill.load(f) - # make the changes - - logger.debug("Fixing CDB issue #1 (linking.filters.cui)") - # Number 1 - # the linking.filters.cuis is set to "{}" - # which is assumed to be an empty set, but actually - # evaluates to an empty dict instead - cuis = data['config']['linking']['filters']['cuis'] - if cuis == {}: - # though it _should_ be the empty set - data['config']['linking']['filters']['cuis'] = set(cuis) - # save modified version - logger.debug("Saving CDB back into %s", new_cdb_path) - with open(new_cdb_path, 'wb') as f: - dill.dump(data, f) - - def _make_archive(self, new_path: str): - logger.debug("Taking data from %s and writing it to %s.zip", - new_path, new_path) - shutil.make_archive( - base_name=new_path, format='zip', base_dir=new_path) - - -def parse_args() -> argparse.Namespace: - """Parse the arguments from the CLI. - - Returns: - argparse.Namespace: The parsed arguments. - """ - parser = argparse.ArgumentParser() - parser.add_argument( - "action", help="The action. 
Currently, only 'fix-config' or 'allow-pre-1.12' are available.", - choices=['fix-config', 'allow-pre-1.12'], type=str.lower) - parser.add_argument("modelpack", help="MedCAT modelpack zip path") - parser.add_argument("newpath", help="The path for the new modelpack") - parser.add_argument( - "--overwrite", help="Allow overvwriting existing files", action="store_true") - parser.add_argument( - "--silent", help="Disable logging", action="store_true") - parser.add_argument( - "--verbose", help="Show debug output", action="store_true") - return parser.parse_args() - - -def setup_logging(args: argparse.Namespace) -> None: - """Setup logging for the runnable based on CLI arguments. - - Args: - args (argparse.Namespace): The parsed arguments. - """ - if not args.silent: - logger.addHandler(logging.StreamHandler()) - if args.verbose: - logger.setLevel(logging.DEBUG) - - -@deprecated("This is no longer needed. Since medcat 1.10 (PR #352) " - "this dealt with automatically upon model load.", - depr_version=(1, 10, 0), removal_version=(1, 14, 0)) -def fix_config(args: argparse.Namespace) -> None: - """Perform the fix-config action based on the CLI arguments. - - Args: - args (argparse.Namespace): The parsed arguments. - """ - logger.debug("Setting up upgrader") - upgrader = ConfigUpgrader(args.modelpack) - logger.debug("Starting the upgrade process") - upgrader.upgrade(args.newpath, overwrite=args.overwrite) - - -def _do_pre_1_12_fix(model_pack_path: str) -> CAT: - cat = CAT.load_model_pack(model_pack_path) - waf = cat.cdb.weighted_average_function - is_def = waf is default_weighted_average - is_partial = (isinstance(waf, partial) - and waf.func is default_weighted_average) - if is_def: - factor = 0.0004 - logger.info("Was using default weighted average") - elif is_partial: - pargs = waf.args - pkwargs = waf.keywords - factor = pargs[0] if pargs else pkwargs['factor'] - logger.info("Was using a (near) default weighted average") - else: - raise IncorrectModel("Model does not have fixable weighted_average tied to its CDB, " - f"found: {waf}") - cat.cdb.weighted_average_function = lambda step: max(0.1, 1 - (step ** 2 * factor)) - return cat - - -def _set_change(val: dict): - return {"py/set": val["==SET=="]} - - -def _pattern_change(val: dict): - return { - "py/object": "re.Pattern", - "pattern": val["==PATTERN=="] - } - - -TO_CHANGE = { - "preprocessing.words_to_skip": _set_change, - "preprocessing.keep_punct": _set_change, - "preprocessing.do_not_normalize": _set_change, - "linking.filters.cuis": _set_change, - "linking.filters.cuis_exclude": _set_change, - "word_skipper": _pattern_change, - "punct_checker": _pattern_change, -} - - -def _fix_config_for_pre_1_12(folder: str): - config_path = os.path.join(folder, 'config.json') - with open(config_path) as f: - data = json.load(f) - for fix_path, fixer in TO_CHANGE.items(): - logger.info("[Pre 1.12 fix] Changing %s", fix_path) - cur_path = fix_path - last_dict = data - while "." in cur_path: - cur_key, cur_path = cur_path.split(".", 1) - last_dict = last_dict[cur_key] - last_key = cur_path - last_value = last_dict[last_key] - last_dict[last_key] = fixer(last_value) - logger.info("[Pre 1.12 fix] Saving config back to %s", config_path) - with open(config_path, 'w') as f: - json.dump(data, f) - logger.info("[Pre 1.12 fix] Recreating archive for %s", folder) - shutil.make_archive(folder, 'zip', root_dir=folder) - - -@deprecated("This is only really needed for 1.12+ models " - "to be converted to lower versions of medcat. 
" - "It should not be needed in the long run.", - depr_version=(1, 13, 0), removal_version=(1, 14, 0)) -def allow_loading_with_pre_1_12(args: argparse.Namespace): - """This method converts a model created after medcat 1.12 - such that it can be loaded in previous versions. - - The main two things it does: - - Simplifies the weighted average function attached to the CDB. - - Makes the config json-compatible - - Expected / used arguments in CLI: - - modelpack: The input model pack path - - newpath: The output model pack path - - overwrite: Whether to overwrite the new model - - Raises: - ValueError: If the file already exists - - Args: - args (argparse.Namespace): The CLI arguments. - """ - # this will fix the weighted_average function if possible - # since 1.12 this is within the CDB and generally refers - # to a method on medcat.utils.config_utils and the method - # and/or the module do not exist in previous version - cat = _do_pre_1_12_fix(args.modelpack) - if not args.overwrite and os.path.exists(args.newpath): - raise ValueError(f"File already exists: {args.newpath}. " - "Set --overwrite to overwrite") - mpn = cat.create_model_pack(args.newpath) - full_path = os.path.join(args.newpath, mpn) - logger.info("Saving model to: %s", full_path) - # now that the model has saved, we also need to do make - # some changes to the config to allow it to be properly - # loaded by jsonpickle (used before 1.12) rather than - # just json (used by 1.12+) - _fix_config_for_pre_1_12(full_path) - - -class IncorrectModel(ValueError): - - def __init__(self, *args: object) -> None: - super().__init__(*args) - - -def main() -> None: - """Run the CLI associated with this module. - - Raises: - ValueError: If an unknown action is provided. - """ - args = parse_args() - setup_logging(args) - logger.debug("Will attempt to perform action %s", args.action) - if args.action == 'fix-config': - fix_config(args) - elif args.action == 'allow-pre-1.12': - allow_loading_with_pre_1_12(args) - else: - raise ValueError(f"Unknown action: {args.action}") - - -if __name__ == "__main__": - main() diff --git a/requirements-dev.txt b/requirements-dev.txt index 6b458abf0..6b954afc9 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,9 +2,9 @@ https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.6.0/en_core_web_md-3.6.0-py3-none-any.whl flake8~=7.0.0 darglint~=1.8.1 -mypy>=1.7.0,<2.0.0 +mypy>=1.7.0,<1.12.0 mypy-extensions>=1.0.0 types-aiofiles==0.8.3 types-PyYAML==6.0.3 types-setuptools==57.4.10 -timeout-decorator==0.5.0 +timeout-decorator==0.5.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 549e7c091..08440b9ec 100644 --- a/setup.py +++ b/setup.py @@ -27,12 +27,12 @@ packages=['medcat', 'medcat.utils', 'medcat.preprocessing', 'medcat.ner', 'medcat.linking', 'medcat.datasets', 'medcat.tokenizers', 'medcat.utils.meta_cat', 'medcat.pipeline', 'medcat.utils.ner', 'medcat.utils.relation_extraction', 'medcat.utils.saving', 'medcat.utils.regression', 'medcat.stats'], + python_requires='>=3.9', # 3.8 is EoL install_requires=install_requires, include_package_data=True, package_data={"medcat": ["install_requires.txt"]}, classifiers=[ "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/tests/resources/model_compatibility/check_backwards_compatibility.sh b/tests/resources/model_compatibility/check_backwards_compatibility.sh 
new file mode 100644
index 000000000..5e3fd2ae1
--- /dev/null
+++ b/tests/resources/model_compatibility/check_backwards_compatibility.sh
@@ -0,0 +1,43 @@
+# CONSTANTS / shouldn't change
+REGRESSION_MODULE="medcat.utils.regression.regression_checker"
+REGRESSION_OPTIONS="--strictness STRICTEST --require-fully-correct"
+
+# CHANGEABLE VALUES
+# target models
+DL_LINK="https://cogstack-medcat-example-models.s3.eu-west-2.amazonaws.com/medcat-example-models/all_fake_medcat_models.zip"
+ZIP_FILE_NAME="all_fake_medcat_models.zip"
+# target regression set
+REGRESSION_TEST_SET="tests/resources/regression/testing/test_model_regresssion.yml"
+# folder to house models under test
+MODEL_FOLDER="fake_models"
+
+# START WORK
+
+echo "Downloading models"
+wget $DL_LINK
+# Create folder if it doesn't exist
+mkdir -p "$MODEL_FOLDER"
+echo "Uncompressing files"
+unzip $ZIP_FILE_NAME -d $MODEL_FOLDER
+echo "Cleaning up the overall zip"
+rm $ZIP_FILE_NAME
+for model_path in `ls $MODEL_FOLDER/*.zip`; do
+    if [ -f "$model_path" ]; then
+        echo "Processing $model_path"
+        python -m $REGRESSION_MODULE \
+            "$model_path" \
+            $REGRESSION_TEST_SET \
+            $REGRESSION_OPTIONS
+        # this is a sanity check - needs to run after so that the folder has been created
+        grep "MedCAT Version" "${model_path%.*}/model_card.json"
+        # clean up here so we don't leave both the .zip'ed model
+        # and the folder so we don't fill the disk
+        echo "Cleaning up at: ${model_path%.*}"
+        rm -rf ${model_path%.*}*
+    else
+        echo "No files found matching the pattern: $model_path"
+    fi
+done
+
+# Remove the fake model folder
+rm -r "$MODEL_FOLDER"
diff --git a/tests/test_cat.py b/tests/test_cat.py
index 4c237f58c..17cdd2819 100644
--- a/tests/test_cat.py
+++ b/tests/test_cat.py
@@ -2,6 +2,8 @@
 import os
 import sys
 import time
+from typing import Callable
+from functools import partial
 import unittest
 from unittest.mock import mock_open, patch
 import tempfile
@@ -595,18 +597,55 @@ def test_get_entities_gets_monitored(self,
             contents = f.readline()
         self.assertTrue(contents)
 
+    def assert_gets_usage_monitored(self, data_processor: Callable[[], None], exp_logs: int = 1):
+        # clear usage monitor buffer
+        self.undertest.usage_monitor.log_buffer.clear()
+        data_processor()
+        file = self.undertest.usage_monitor.log_file
+        if os.path.exists(file):
+            with open(file) as f:
+                content = f.readlines()
+            content += self.undertest.usage_monitor.log_buffer
+        else:
+            content = self.undertest.usage_monitor.log_buffer
+        self.assertTrue(content)
+        self.assertEqual(len(content), exp_logs)
+
     def test_get_entities_logs_usage(self, text="The dog is sitting outside the house."):
         # clear usage monitor buffer
-        self.undertest.usage_monitor.log_buffer.clear()
-        self.undertest.get_entities(text)
-        self.assertTrue(self.undertest.usage_monitor.log_buffer)
-        self.assertEqual(len(self.undertest.usage_monitor.log_buffer), 1)
+        self.assert_gets_usage_monitored(partial(self.undertest.get_entities, text), 1)
         line = self.undertest.usage_monitor.log_buffer[0]
         # the 1st element is the input text length
         input_text_length = line.split(",")[1]
         self.assertEqual(str(len(text)), input_text_length)
 
+    TEXT4MP_USAGE = [
+        ("ID1", "Text with house and dog one"),
+        ("ID2", "Text with house and dog two"),
+        ("ID3", "Text with house and dog three"),
+        ("ID4", "Text with house and dog four"),
+        ("ID5", "Text with house and dog five"),
+        ("ID6", "Text with house and dog six"),
+        ("ID7", "Text with house and dog seven"),
+        ("ID8", "Text with house and dog eight"),
+    ]
+
+    def test_mp_batch_char_size_logs_usage(self):
+
all_text = self.TEXT4MP_USAGE + proc = partial(self.undertest.multiprocessing_batch_char_size, all_text, nproc=2) + self.assert_gets_usage_monitored(proc, len(all_text)) + + def test_mp_get_multi_texts_logs_usage(self): + all_text = self.TEXT4MP_USAGE + proc = partial(self.undertest.get_entities_multi_texts, all_text, n_process=2) + self.assert_gets_usage_monitored(proc, len(all_text)) + + def test_mp_batch_docs_size_logs_usage(self): + all_text = self.TEXT4MP_USAGE + proc = partial(self.undertest.multiprocessing_batch_docs_size, all_text, nproc=2) + self.assert_gets_usage_monitored(proc, len(all_text)) + def test_simple_hashing_is_faster(self): self.undertest.config.general.simple_hash = False st = time.perf_counter() diff --git a/tests/utils/regression/test_utils.py b/tests/utils/regression/test_utils.py index fa50af074..93f621bc5 100644 --- a/tests/utils/regression/test_utils.py +++ b/tests/utils/regression/test_utils.py @@ -8,6 +8,8 @@ from medcat.utils.regression import utils from medcat.utils.regression.checking import RegressionSuite +from medcat.utils.normalizers import get_all_edits_n + class PartialSubstituationTests(TestCase): TEXT1 = "This [PH1] has one placeholder" @@ -181,3 +183,47 @@ def test_unchanged_has_class_doc_Strings(self): for ec in MyE3: with self.subTest(str(ec)): self.assertEqual(ec.__doc__, MyE3.__doc__) + + +class EditBaseTests(TestCase): + WORDS = ['WORDs', 'multi word', 'long-ass-word', + 'complexsuperlongwordthatexists'] + + +class EditEstimationTests(EditBaseTests): + + def assert_can_estimate_dist(self, edit_distance: int, tol_perc: float): + for word in self.WORDS: + with self.subTest(word): + got = len(list(get_all_edits_n(word, False, edit_distance))) + expected = utils.estimate_num_variants(len(word), edit_distance) + ratio = got / expected + self.assertTrue(1 - tol_perc < ratio < 1 + tol_perc, + f"Ratio {ratio} vs TOL {tol_perc}") + + def test_can_estimate_dist1(self): + self.assert_can_estimate_dist(1, 0.04) + + def test_can_estimate_dist2(self): + self.assert_can_estimate_dist(2, 0.06) + + +class EditTests(EditBaseTests): + ORIG_WORD = "WORD" + EDIT_DIST = 1 + ALL_EDITS = list(get_all_edits_n(ORIG_WORD, False, EDIT_DIST)) + LEN = len(ALL_EDITS) + # NOTE: can't use the full length since the estimation is lower + PICKS = [1, 5, 10, LEN - 20] + RNG_SEED = 42 + + def test_can_pick_correct_number(self): + for pick in self.PICKS: + with self.subTest(f"Pick {pick}"): + picked = list(utils.pick_random_edits( + self.ALL_EDITS, edit_distance=self.EDIT_DIST, + num_to_pick=pick, orig_len=len(self.ORIG_WORD), + rng_seed=self.RNG_SEED)) + self.assertEqual(len(picked), pick) + # make sure the names are unique + self.assertEqual(len(picked), len(set(picked))) diff --git a/tests/utils/test_normalizers.py b/tests/utils/test_normalizers.py new file mode 100644 index 000000000..cb0140c7a --- /dev/null +++ b/tests/utils/test_normalizers.py @@ -0,0 +1,71 @@ +import unittest + +from medcat.utils import normalizers + + +class EditOrderTests(unittest.TestCase): + WORD = "abc" + EXMAPLE_EDITS_ORDER = [ + 'abqc', 'rbc', 'obc', 'fbc', 'abyc', + 'azbc', 'ibc', 'xbc', 'apc', 'abcl', + 'abcr', 'abck', 'anc', 'abd', 'abkc', + 'iabc', 'tbc', 'cabc', 'abw', 'abp', + 'abe', 'akbc', 'apbc', 'hbc', 'ubc', + 'abic', 'babc', 'abcq', 'wabc', 'abtc', + 'aibc', 'yabc', 'asc', 'abrc', 'avbc', + 'abu', 'kabc', 'axc', 'fabc', 'nbc', + 'rabc', 'abec', 'abcu', 'gbc', 'amc', + 'abce', 'abdc', 'abcy', 'bbc', 'dbc', + 'abac', 'abvc', 'abuc', 'avc', 'abi', + 'abm', 'abjc', 'abcp', 'tabc', 'cbc', + 
'uabc', 'abz', 'aby', 'qbc', 'abcf', + 'abpc', 'axbc', 'abk', 'gabc', 'abc', + 'mbc', 'aqbc', 'abci', 'oabc', 'qabc', + 'abf', 'vabc', 'abj', 'abbc', 'aubc', + 'acbc', 'abn', 'aebc', 'ebc', 'abfc', + 'dabc', 'abh', 'arc', 'aqc', 'albc', + 'aac', 'abcb', 'sabc', 'ybc', 'abcv', + 'absc', 'abca', 'labc', 'ajbc', 'kbc', + 'pabc', 'abcc', 'afbc', 'sbc', 'abl', + 'awc', 'ahbc', 'abco', 'anbc', 'abo', + 'abg', 'abcn', 'awbc', 'adc', 'ahc', + 'habc', 'abb', 'vbc', 'aboc', 'abq', + 'acc', 'agc', 'abcx', 'nabc', 'abwc', + 'lbc', 'abcm', 'afc', 'ab', 'atc', + 'aybc', 'akc', 'abt', 'aic', 'jbc', + 'aec', 'zabc', 'agbc', 'abv', 'abnc', + 'abcj', 'pbc', 'abcg', 'bac', 'abr', + 'aobc', 'abcd', 'alc', 'aoc', 'ajc', + 'abx', 'arbc', 'ayc', 'aba', 'abcw', + 'eabc', 'abcs', 'abhc', 'adbc', 'abgc', + 'asbc', 'acb', 'abs', 'aabc', 'abzc', + 'abxc', 'atbc', 'ambc', 'jabc', 'bc', + 'wbc', 'abcz', 'ablc', 'ac', 'azc', + 'abct', 'abmc', 'zbc', 'abch', 'auc', + 'xabc', 'mabc' + ] + + # NOTE: The there is a chance that this test fails. But it should be 2 in 182! + # (since I'm checking against 2 different orders - the one captured above + # and the alphabetically ordered version calculated on the fly). This is + # essentially 0 and _should_ never happen. + def test_order_not_guaranteed1(self): + all_edits = list(normalizers.get_all_edits_n(self.WORD, use_diacritics=False, n=1, return_ordered=False)) + self.assertNotEqual(all_edits, self.EXMAPLE_EDITS_ORDER) + self.assertNotEqual(all_edits, sorted(all_edits)) + + def test_ordered_within_same_run1(self): + all_edits1 = list(normalizers.get_all_edits_n(self.WORD, use_diacritics=False, n=1, return_ordered=False)) + all_edits2 = list(normalizers.get_all_edits_n(self.WORD, use_diacritics=False, n=1, return_ordered=False)) + self.assertEqual(all_edits1, all_edits2) + + def test_all_items_same_now1(self): + all_edits = list(normalizers.get_all_edits_n(self.WORD, use_diacritics=False, n=1, return_ordered=False)) + for got_now in all_edits: + with self.subTest(got_now): + self.assertIn(got_now, self.EXMAPLE_EDITS_ORDER) + + def test_can_guarantee_order1(self): + all_edits1 = list(normalizers.get_all_edits_n(self.WORD, use_diacritics=False, n=1, return_ordered=True)) + ordered = sorted(all_edits1) + self.assertEqual(all_edits1, ordered) diff --git a/tests/utils/test_preprocess_snomed.py b/tests/utils/test_preprocess_snomed.py index 59a00f6fc..d7c2f6629 100644 --- a/tests/utils/test_preprocess_snomed.py +++ b/tests/utils/test_preprocess_snomed.py @@ -1,7 +1,11 @@ +import os from typing import Dict +import contextlib + from medcat.utils import preprocess_snomed import unittest +from unittest.mock import patch EXAMPLE_REFSET_DICT: Dict = { @@ -45,20 +49,130 @@ def test_example_no_codfe_fails(self): with self.assertRaises(KeyError): preprocess_snomed.get_direct_refset_mapping(EXAMPLE_REFSET_DICT_NO_CODE) + EXAMPLE_SNOMED_PATH_OLD = "SnomedCT_InternationalRF2_PRODUCTION_20220831T120000Z" +EXAMPLE_SNOMED_PATH_OLD_UK = "SnomedCT_UKClinicalRF2_PRODUCTION_20220831T120000Z" EXAMPLE_SNOMED_PATH_NEW = "SnomedCT_UKClinicalRF2_PRODUCTION_20231122T000001Z" -class TestSnomedVersionsOPCS4(unittest.TestCase): +@contextlib.contextmanager +def patch_fake_files(path: str, subfiles: list = [], + subdirs: list = ["Snapshot"]): + def cur_listdir(file_path: str, *args, **kwargs) -> list: + if file_path == path: + return subfiles + subdirs + for sd in subdirs: + subdir = os.path.join(path, sd) + if subdir == path: + return [] + raise FileNotFoundError(path) - def 
test_old_gets_old_OPCS4_mapping_nonuk_ext(self): - snomed = preprocess_snomed.Snomed(EXAMPLE_SNOMED_PATH_OLD, uk_ext=False) - self.assertEqual(snomed.opcs_refset_id, "1126441000000105") + def cur_isfile(file_path: str, *args, **kwargs) -> bool: + print("CUR isfile", file_path) + return file_path == path or file_path in [os.path.join(path, subfiles)] + + def cur_isdir(file_path: str, *args, **kwrags) -> bool: + print("CUR isdir", file_path) + return file_path == path or file_path in [os.path.join(path, subdirs)] + + with patch("os.listdir", new=cur_listdir): + with patch("os.path.isfile", new=cur_isfile): + with patch("os.path.isdir", new=cur_isdir): + yield - def test_old_gets_old_OPCS4_mapping_uk_ext(self): - snomed = preprocess_snomed.Snomed(EXAMPLE_SNOMED_PATH_OLD, uk_ext=True) + +class TestSnomedVersionsOPCS4(unittest.TestCase): + + def test_old_gets_old_OPCS4_mapping(self): + with patch_fake_files(EXAMPLE_SNOMED_PATH_OLD): + snomed = preprocess_snomed.Snomed(EXAMPLE_SNOMED_PATH_OLD) + snomed._set_extension(snomed._determine_release(EXAMPLE_SNOMED_PATH_OLD), + snomed._determine_extension(EXAMPLE_SNOMED_PATH_OLD)) + self.assertEqual(snomed.opcs_refset_id, "1382401000000109") # defaults to this now + + def test_old_gets_old_OPCS4_mapping_UK(self): + with patch_fake_files(EXAMPLE_SNOMED_PATH_OLD_UK): + snomed = preprocess_snomed.Snomed(EXAMPLE_SNOMED_PATH_OLD_UK) + snomed._set_extension(snomed._determine_release(EXAMPLE_SNOMED_PATH_OLD_UK), + snomed._determine_extension(EXAMPLE_SNOMED_PATH_OLD_UK)) self.assertEqual(snomed.opcs_refset_id, "1126441000000105") - def test_new_gets_new_OCPS4_mapping_uk_ext(self): - snomed = preprocess_snomed.Snomed(EXAMPLE_SNOMED_PATH_NEW, uk_ext=True) + def test_new_gets_new_OCPS4_mapping(self): + with patch_fake_files(EXAMPLE_SNOMED_PATH_NEW): + snomed = preprocess_snomed.Snomed(EXAMPLE_SNOMED_PATH_NEW) + snomed._set_extension(snomed._determine_release(EXAMPLE_SNOMED_PATH_NEW), + snomed._determine_extension(EXAMPLE_SNOMED_PATH_NEW)) self.assertEqual(snomed.opcs_refset_id, "1382401000000109") + + +class TestSnomedModelGetter(unittest.TestCase): + WORKING_BASE_NAMES = [ + "SnomedCT_InternationalRF2_PRODUCTION_20240201T120000Z", + "SnomedCT_InternationalRF2_PRODUCTION_20240601T120000Z", + "SnomedCT_UKClinicalRF2_PRODUCTION_20240410T000001Z", + "SnomedCT_UKClinicalRefsetsRF2_PRODUCTION_20240410T000001Z", + "SnomedCT_UKDrugRF2_PRODUCTION_20240508T000001Z", + "SnomedCT_UKEditionRF2_PRODUCTION_20240410T000001Z", + "SnomedCT_UKEditionRF2_PRODUCTION_20240508T000001Z", + "SnomedCT_Release_AU1000036_20240630T120000Z", + ] + FAILING_BASE_NAMES = [ + "uk_sct2cl_38.2.0_20240605000001Z", + "uk_sct2cl_32.6.0_20211027000001Z", + ] + PATH = os.path.join("path", "to", "release") + + def _pathify(self, in_list: list) -> list: + return [os.path.join(self.PATH, folder) for folder in in_list] + + def assert_got_version(self, snomed: preprocess_snomed.Snomed, raw_name: str): + rel_list = snomed.snomed_releases + self.assertIsInstance(rel_list, list) + self.assertEqual(len(rel_list), 1) + rel = rel_list[0] + self.assertIsInstance(rel, str) + self.assertIn(rel, raw_name) + self.assertEqual(rel, raw_name[-16:-8]) + + def assert_all_work(self, all_paths: list): + for path in all_paths: + with self.subTest(f"Rrelease name: {path}"): + with patch_fake_files(path): + snomed = preprocess_snomed.Snomed(path) + self.assert_got_version(snomed, path) + + def test_gets_model_form_basename(self): + self.assert_all_work(self.WORKING_BASE_NAMES) + + def test_gets_model_from_path(self): + 
full_paths = self._pathify(self.WORKING_BASE_NAMES) + self.assert_all_work(full_paths) + + def assert_raises(self, folder_path: str): + with self.assertRaises(preprocess_snomed.UnkownSnomedReleaseException): + preprocess_snomed.Snomed._determine_release(folder_path, strict=True) + + def assert_all_raise(self, folder_paths: list): + for folder_path in folder_paths: + with self.subTest(f"Folder: {folder_path}"): + self.assert_raises(folder_path) + + def test_fails_on_incorrect_names_strict(self): + self.assert_all_raise(self.FAILING_BASE_NAMES) + + def test_fails_on_incorrect_paths_strict(self): + full_paths = self._pathify(self.FAILING_BASE_NAMES) + self.assert_all_raise(full_paths) + + def assert_all_get_no_version(self, folder_paths: list): + for folder_path in folder_paths: + with self.subTest(f"Folder: {folder_path}"): + det_rel = preprocess_snomed.Snomed._determine_release(folder_path, strict=False) + self.assertEqual(det_rel, preprocess_snomed.Snomed.NO_VERSION_DETECTED) + + def test_gets_no_version_incorrect_names_nonstrict(self): + self.assert_all_get_no_version(self.FAILING_BASE_NAMES) + + def test_gets_no_version_incorrect_paths_nonstrict(self): + full_paths = self._pathify(self.FAILING_BASE_NAMES) + self.assert_all_get_no_version(full_paths) diff --git a/tests/utils/test_usage_monitoring.py b/tests/utils/test_usage_monitoring.py index 936cde37f..b47345bf3 100644 --- a/tests/utils/test_usage_monitoring.py +++ b/tests/utils/test_usage_monitoring.py @@ -89,7 +89,7 @@ def test_some_in_file(self): self.assertEqual(len(lines), self.expected_in_file) -class UMT(UsageMonitorBaseTests): +class UsageMonitoringAutoTests(UsageMonitorBaseTests): ENABLED_DICT = { "MEDCAT_USAGE_LOGS": "True", "MEDCAT_USAGE_LOGS_LOCATION": "." diff --git a/tests/utils/test_versioning.py b/tests/utils/test_versioning.py deleted file mode 100644 index 30a3afdd3..000000000 --- a/tests/utils/test_versioning.py +++ /dev/null @@ -1,163 +0,0 @@ -import unittest -import os -import tempfile -import shutil - -import dill -import pydantic - -from medcat.utils.versioning import get_version_from_modelcard, get_semantic_version_from_model -from medcat.utils.versioning import get_version_from_cdb_dump, get_version_from_modelpack_zip -from medcat.utils.versioning import ConfigUpgrader -from medcat.cat import CAT -from medcat.cdb import CDB -from medcat.vocab import Vocab - -from .regression.test_metadata import MODEL_CARD_EXAMPLE, EXAMPLE_VERSION - - -CORRECT_SEMANTIC_VERSIONS = [("1.0.1-alpha-1", (1, 0, 1)), ("0.0.1-alpha-1", (0, 0, 1)), - ("1.0.0-alpha.1", (1, 0, 0) - ), ("1.0.0-0.3.7", (1, 0, 0)), - ("1.0.0-x.7.z.92", (1, 0, 0) - ), ("1.0.0-x-y-z.--", (1, 0, 0)), - ("1.0.0-alpha+001", (1, 0, 0) - ), ("1.0.0+20130313144700", (1, 0, 0)), - ("1.0.0-beta+exp.sha.5114f85", (1, 0, 0)), - ("1.0.0+21AF26D3----117B344092BD", (1, 0, 0))] -INCORRECT_SEMANTIC_VERSIONS = ["01.0.0", "0.01.0", "0.0.01", "0.0.0\nSOMETHING", - "1.0.space", "1.0.0- space"] - - -class VersionGettingFromModelCardTests(unittest.TestCase): - FAKE_MODEL_CARD1 = {"Something": "value"} - FAKE_MODEL_CARD2 = {"MedCAT Version": "not semantic"} - FAKE_MODEL_CARD3 = {"MedCAT Version": "almost.semantic"} - FAKE_MODEL_CARD4 = {"MedCAT Version": "closest.to.semantic"} - WRONG_VERSION_FAKE_MODELS = [FAKE_MODEL_CARD2, - FAKE_MODEL_CARD3, FAKE_MODEL_CARD4] - - def test_gets_correct_version(self): - maj, minor, patch = get_version_from_modelcard(MODEL_CARD_EXAMPLE) - self.assertEqual(EXAMPLE_VERSION, (maj, minor, patch)) - - def 
test_fails_upon_model_card_with_no_version_defined(self): - with self.assertRaises(KeyError): - get_version_from_modelcard(self.FAKE_MODEL_CARD1) - - def test_fails_upon_model_card_with_incorrect_version(self): - cntr = 0 - for fake_model_card in self.WRONG_VERSION_FAKE_MODELS: - with self.assertRaises(ValueError): - get_version_from_modelcard(fake_model_card) - cntr += 1 - self.assertEqual(cntr, len(self.WRONG_VERSION_FAKE_MODELS)) - - def test_fails_upon_wrong_version(self): - cntr = 0 - for wrong_version in INCORRECT_SEMANTIC_VERSIONS: - d = {"MedCAT Version": wrong_version} - with self.subTest(f"With version: {wrong_version}"): - with self.assertRaises(ValueError): - get_version_from_modelcard(d) - cntr += 1 - self.assertEqual(cntr, len(INCORRECT_SEMANTIC_VERSIONS)) - - def test_gets_version_from_correct_versions(self): - cntr = 0 - for version, expected in CORRECT_SEMANTIC_VERSIONS: - d = {"MedCAT Version": version} - with self.subTest(f"With version: {version}"): - got_version = get_version_from_modelcard(d) - self.assertEqual(got_version, expected) - cntr += 1 - self.assertEqual(cntr, len(CORRECT_SEMANTIC_VERSIONS)) - - -NEW_CDB_NAME = "cdb_new.dat" -CDB_PATH = os.path.join(os.path.dirname( - os.path.realpath(__file__)), "..", "..", "examples", NEW_CDB_NAME) -EXPECTED_CDB_VERSION = (1, 0, 0) - - -class VersionGettingFromCATTests(unittest.TestCase): - - def setUp(self) -> None: - self.cdb = CDB.load(CDB_PATH) - self.vocab = Vocab.load(os.path.join(os.path.dirname( - os.path.realpath(__file__)), "..", "..", "examples", "vocab.dat")) - self.cdb.config.general.spacy_model = "en_core_web_md" - self.cdb.config.ner.min_name_len = 2 - self.cdb.config.ner.upper_case_limit_len = 3 - self.cdb.config.general.spell_check = True - self.cdb.config.linking.train_count_threshold = 10 - self.cdb.config.linking.similarity_threshold = 0.3 - self.cdb.config.linking.train = True - self.cdb.config.linking.disamb_length_limit = 5 - self.cdb.config.general.full_unlink = True - self.meta_cat_dir = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "tmp") - self.undertest = CAT( - cdb=self.cdb, config=self.cdb.config, vocab=self.vocab, meta_cats=[]) - - def test_gets_correct_version(self): - version = get_semantic_version_from_model(self.undertest) - self.assertEqual(EXPECTED_CDB_VERSION, version) - - -class VersionGetterFromCDBTests(unittest.TestCase): - - def test_gets_version_from_cdb(self): - version = get_version_from_cdb_dump(CDB_PATH) - self.assertEqual(EXPECTED_CDB_VERSION, version) - - -class VersionGettFromModelPackTests(unittest.TestCase): - - def test_gets_version_from_model_pack(self): - # not strictly speaking a ZIP, but should work currently - # since the folder exists - model_pack_zip = os.path.dirname(CDB_PATH) - version = get_version_from_modelpack_zip( - model_pack_zip, cdb_file_name=NEW_CDB_NAME) - self.assertEqual(EXPECTED_CDB_VERSION, version) - - -class VersioningFixTests(unittest.TestCase): - - def break_cdb(self): - with open(self.broken_cdb_path, 'rb') as rf: - data = dill.load(rf) - data['config']['linking']['filters']['cuis'] = {} - with open(self.broken_cdb_path, 'wb') as wf: - dill.dump(data, wf) - - def setUp(self) -> None: - self.temp_folder = tempfile.TemporaryDirectory() - self.broken_cdb_path = os.path.join(self.temp_folder.name, "cdb.dat") - self.new_temp_folder = tempfile.TemporaryDirectory() - shutil.copyfile(CDB_PATH, self.broken_cdb_path) - self.break_cdb() - - def tearDown(self) -> None: - self.temp_folder.cleanup() - self.new_temp_folder.cleanup() - - def 
test_new_format_does_not_change_when_upgraded(self): - fixer = ConfigUpgrader(os.path.dirname( - CDB_PATH), cdb_file_name=NEW_CDB_NAME) - fixer.upgrade(self.new_temp_folder.name) - old_cdb = CDB.load(CDB_PATH) - new_cdb = CDB.load(os.path.join( - self.new_temp_folder.name, NEW_CDB_NAME)) - self.assertEqual(old_cdb.config.get_hash(), new_cdb.config.get_hash()) - - def test_old_format_needs_upgrade(self): - fixer = ConfigUpgrader(self.temp_folder.name) - self.assertTrue(fixer.needs_upgrade()) - - def test_fixes_old_format(self): - fixer = ConfigUpgrader(self.temp_folder.name) - fixer.upgrade(self.new_temp_folder.name) - new_cdb = CDB.load(os.path.join(self.new_temp_folder.name, "cdb.dat")) - self.assertIsInstance(new_cdb, CDB)
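
Note (editorial sketch, not part of the patch): the edit-distance machinery added above is exercised like this, assuming only the signatures visible in the diff (get_all_edits_n from medcat.utils.normalizers, estimate_num_variants and pick_random_edits from medcat.utils.regression.utils); the word and numbers below are purely illustrative.

    from medcat.utils.normalizers import get_all_edits_n
    from medcat.utils.regression import utils

    word = "WORD"  # any concept name would do here
    # all edit-distance-1 variants of the name, in a deterministic (sorted) order
    edits = list(get_all_edits_n(word, use_diacritics=False, n=1, return_ordered=True))
    # rough expected variant count, taken from ESTIMATION_MATRIX (an estimate, not an exact count)
    approx = utils.estimate_num_variants(len(word), 1)
    # sample a fixed, seed-stable subset of the variants so the regression suite stays fast
    picked = list(utils.pick_random_edits(edits, num_to_pick=10,
                                          orig_len=len(word), edit_distance=1,
                                          rng_seed=42))

From the command line the same behaviour is driven by the new --edit-distance flag, e.g. --edit-distance "(1, 42, 10)" for edit distance 1, seed 42 and 10 edited names per name, which tuple3_parser converts into the (N, R, P) tuple passed through to main().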