Skip to content

Commit

Permalink
update model
Browse files Browse the repository at this point in the history
  • Loading branch information
QinbinLi committed Aug 19, 2024
1 parent 589f7e0 commit 4da2294
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions models/ft_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def load_local_model(self, model_path=None):
self._tokenizer.pad_token = self._tokenizer.eos_token
self._lm.config.pad_token_id = self._lm.config.eos_token_id

def query(self, text):
def query(self, text, new_str_only=False):
"""
Query an open-source model with a given text prompt.
Expand All @@ -106,11 +106,15 @@ def query(self, text):
# top_k=sampling_args.top_k,
# top_p=sampling_args.top_p,
# output_scores=True,
return_dict_in_generate=True
return_dict_in_generate=True,

)

# Decode the generated text back to a readable string
generated_text = self._tokenizer.decode(output.sequences[0], skip_special_tokens=True)
if new_str_only:
generated_text = self._tokenizer.decode(output.sequences[0][len(input_ids[0]):], skip_special_tokens=True)
else:
generated_text = self._tokenizer.decode(output.sequences[0], skip_special_tokens=True)
return generated_text

def evaluate(self, text, tokenized=False):
Expand Down

0 comments on commit 4da2294

Please sign in to comment.