From dba04b9639408af8b41ff887b99abb050eb08d51 Mon Sep 17 00:00:00 2001
From: eriknovak
Date: Thu, 23 May 2024 21:28:21 +0200
Subject: [PATCH] Add pattern restriction to MaskLabelGenerator + generalize
 model use

---
 .../anonymize/generators/mask_label_generator.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/anonipy/anonymize/generators/mask_label_generator.py b/anonipy/anonymize/generators/mask_label_generator.py
index d143bde..cc8390b 100644
--- a/anonipy/anonymize/generators/mask_label_generator.py
+++ b/anonipy/anonymize/generators/mask_label_generator.py
@@ -34,16 +34,18 @@ def __init__(
             )
             use_gpu = False
 
+        # prepare the fill-mask pipeline and store the mask token
         model, tokenizer = self._prepare_model_and_tokenizer(model_name, use_gpu)
+        self.mask_token = tokenizer.mask_token
         self.pipeline = pipeline(
-            "fill-mask", model=model, tokenizer=tokenizer, top_k=10
+            "fill-mask", model=model, tokenizer=tokenizer, top_k=40
         )
 
     def generate(self, entity: Entity, text: str, *args, **kwargs):
         masks = self._create_masks(entity)
         input_texts = self._prepare_generate_inputs(masks, text)
         suggestions = self.pipeline(input_texts)
-        return self._create_substitute(masks, suggestions)
+        return self._create_substitute(entity, masks, suggestions)
 
     # =================================
     # Private methods
@@ -67,7 +69,7 @@ def _create_masks(self, entity: Entity):
                 {
                     "true_text": chunks[idx],
                     "mask_text": " ".join(
-                        chunks[0:idx] + ["<mask>"] + chunks[idx + 1 :]
+                        chunks[0:idx] + [self.mask_token] + chunks[idx + 1 :]
                     ),
                     "start_index": entity.start_index,
                     "end_index": entity.end_index,
@@ -90,14 +92,15 @@ def _prepare_generate_inputs(self, masks, text):
             for m in masks
         ]
 
-    def _create_substitute(self, masks, suggestions):
+    def _create_substitute(self, entity: Entity, masks, suggestions):
         substitute_chunks = []
         for mask, suggestion in zip(masks, suggestions):
             suggestion = suggestion if type(suggestion) == list else [suggestion]
             viable_suggestions = list(
                 filter(
-                    lambda x: x["token_str"] not in STOPWORDS
-                    and x["token_str"] != mask["true_text"],
+                    lambda x: x["token_str"] != mask["true_text"]
+                    and re.match(entity.regex, x["token_str"])
+                    and x["token_str"] not in STOPWORDS,
                     suggestion,
                 )
            )
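
A minimal standalone sketch of the suggestion filter this patch introduces.
All values are hypothetical stand-ins, not the library's API: entity_regex,
true_text, and the suggestions list play the roles of the Entity's regex, a
mask record's "true_text", and the fill-mask pipeline's top_k output.

    import re

    entity_regex = r"\d{4}"             # e.g. substitutes must be four digits
    true_text = "1984"                  # the original chunk being masked
    STOPWORDS = {"the", "a", "an"}      # stand-in stopword set

    # hypothetical top_k output of a transformers fill-mask pipeline
    suggestions = [
        {"token_str": "1984", "score": 0.41},  # dropped: equals the true text
        {"token_str": "2001", "score": 0.22},  # kept: matches the pattern
        {"token_str": "year", "score": 0.18},  # dropped: fails the regex
        {"token_str": "the", "score": 0.02},   # dropped: stopword
    ]

    # keep a suggestion only if it differs from the original chunk, matches
    # the entity's regex, and is not a stopword: the same three-way test as
    # the patched _create_substitute
    viable = [
        s for s in suggestions
        if s["token_str"] != true_text
        and re.match(entity_regex, s["token_str"])
        and s["token_str"] not in STOPWORDS
    ]
    print([s["token_str"] for s in viable])  # -> ['2001']

Note that re.match anchors only at the start of the string, so as in the
patch this is a prefix match, not a full match. The other two changes work
together with the filter: reading the mask token from the tokenizer instead
of hardcoding "<mask>" lets the generator drive models whose mask token
differs (BERT-style models use "[MASK]"), and raising top_k from 10 to 40
gives the stricter regex filter more candidates to survive it.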