From 739245eeef8a955ae9362f74431eea53e984057b Mon Sep 17 00:00:00 2001 From: vladd-bit Date: Fri, 16 Feb 2024 17:40:09 +0000 Subject: [PATCH] Fixed model card info not being displayed for model_packs. --- .../nlp_processor/medcat_processor.py | 32 +++++++++++++++---- 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/medcat_service/nlp_processor/medcat_processor.py b/medcat_service/nlp_processor/medcat_processor.py index 3208052..216c25e 100644 --- a/medcat_service/nlp_processor/medcat_processor.py +++ b/medcat_service/nlp_processor/medcat_processor.py @@ -9,6 +9,7 @@ import simplejson as json from medcat.cat import CAT from medcat.cdb import CDB +from medcat.config import Config from medcat.meta_cat import MetaCAT from medcat.utils.ner.deid import DeIdModel from medcat.vocab import Vocab @@ -65,7 +66,8 @@ def __init__(self): self.bulk_nproc = int(os.getenv("APP_BULK_NPROC", 8)) self.torch_threads = int(os.getenv("APP_TORCH_THREADS", -1)) - self.DEID_MODE = os.getenv("DEID_MODE", False) + self.DEID_MODE = os.getenv("DEID_MODE", "False") + self.model_card_info = {} # this is available to constrain torch threads when there # isn't a GPU @@ -79,7 +81,6 @@ def __init__(self): self.cat = self._create_cat() self.cat.train = os.getenv("APP_TRAINING_MODE", False) - self.log.info("MedCAT processor is ready") def get_app_info(self): @@ -90,7 +91,9 @@ def get_app_info(self): return {"service_app_name": self.app_name, "service_language": self.app_lang, "service_version": self.app_version, - "service_model": self.app_model} + "service_model": self.app_model, + "model_card_info": self.model_card_info + } def process_entities(self, entities): if type(entities) is dict: @@ -180,9 +183,8 @@ def process_content_bulk(self, content): start_time_ns = time.time_ns() try: - if self.DEID_MODE: + if eval(self.DEID_MODE): ann_res = self.cat.deid_text() - pass else: ann_res = self.cat.multiprocessing( MedCatProcessor._generate_input_doc(content, invalid_doc_ids), nproc=nproc) @@ -214,6 +216,13 @@ def retrain_medcat(self, content, replace_cdb): return {"results": [p, r, f1, tp_dict, fp_dict, fn_dict]} + def _populate_model_card_info(self, config: Config): + self.model_card_info["ontologies"] = config.version.ontology \ + if (type(config.version.ontology) == list) else str(config.version.ontology) + self.model_card_info["meta_cat_model_names"] = [i["Category Name"] for i in config.version.meta_cats] \ + if (type(config.version.meta_cats) == list) else str(config.version.meta_cats) + self.model_card_info["model_last_modified_on"] = str(config.version.last_modified) + # helper MedCAT methods # def _create_cat(self): @@ -235,13 +244,19 @@ def _create_cat(self): self.log.info("Loading model pack...") cat = CAT.load_model_pack(model_pack_path) - if self.DEID_MODE: + if eval(self.DEID_MODE): cat = DeIdModel.load_model_pack(model_pack_path) # Apply CUI filter if provided if os.getenv("APP_MODEL_CUI_FILTER_PATH", None) is not None: self.log.debug("Applying CUI filter ...") cat.cdb.filter_by_cui(cuis_to_keep) + + if self.app_model.lower() in ["", "unknown", "medmen"]: + self.app_model = cat.config.version.id + + self._populate_model_card_info(cat.config) + return cat else: self.log.info("APP_MEDCAT_MODEL_PACK not set, skipping....") @@ -293,10 +308,15 @@ def _create_cat(self): if cat: meta_models.extend(cat._meta_cats) + if self.app_model.lower() in [None, "unknown"]: + self.app_model = cdb.config.version.id + config.general["log_level"] = os.getenv("LOG_LEVEL", logging.INFO) cat = CAT(cdb=cdb, config=config, vocab=vocab, meta_cats=meta_models) + self._populate_model_card_info(cat.config) + return cat # helper generator functions to avoid multiple copies of data