From 4da2294e946cd83850a1a3b8fc631f8e14b5baf5 Mon Sep 17 00:00:00 2001 From: Qinbin Li Date: Mon, 19 Aug 2024 12:00:59 +0800 Subject: [PATCH] update model --- models/ft_clm.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/models/ft_clm.py b/models/ft_clm.py index 7976dee..a6580be 100644 --- a/models/ft_clm.py +++ b/models/ft_clm.py @@ -80,7 +80,7 @@ def load_local_model(self, model_path=None): self._tokenizer.pad_token = self._tokenizer.eos_token self._lm.config.pad_token_id = self._lm.config.eos_token_id - def query(self, text): + def query(self, text, new_str_only=False): """ Query an open-source model with a given text prompt. @@ -106,11 +106,15 @@ def query(self, text): # top_k=sampling_args.top_k, # top_p=sampling_args.top_p, # output_scores=True, - return_dict_in_generate=True + return_dict_in_generate=True, + ) # Decode the generated text back to a readable string - generated_text = self._tokenizer.decode(output.sequences[0], skip_special_tokens=True) + if new_str_only: + generated_text = self._tokenizer.decode(output.sequences[0][len(input_ids[0]):], skip_special_tokens=True) + else: + generated_text = self._tokenizer.decode(output.sequences[0], skip_special_tokens=True) return generated_text def evaluate(self, text, tokenized=False):