Feature: use hf chat support #1047
Changes from 4 commits: 87ab1ca, 9449258, 9ec5794, 99c760a, 900fc8f, 86de116
File: garak/generators/huggingface.py
@@ -79,6 +79,9 @@ def _load_client(self):
             set_seed(_config.run.seed)

         pipeline_kwargs = self._gather_hf_params(hf_constructor=pipeline)
+        pipeline_kwargs["truncation"] = (
+            True  # this is forced to maintain existing pipeline expectations
+        )
         self.generator = pipeline("text-generation", **pipeline_kwargs)
         if self.generator.tokenizer is None:
             # account for possible model without a stored tokenizer
@@ -87,6 +90,10 @@ def _load_client(self):
             self.generator.tokenizer = AutoTokenizer.from_pretrained(
                 pipeline_kwargs["model"]
             )
+        self.use_chat = (
+            hasattr(self.generator.tokenizer, "chat_template")
+            and self.generator.tokenizer.chat_template is not None
+        )

         if not hasattr(self, "deprefix_prompt"):
             self.deprefix_prompt = self.name in models_to_deprefix

         if _config.loaded:
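For context, here is a minimal sketch (not part of the PR) of the same capability probe run on a standalone tokenizer; whether it reports chat support depends entirely on whether the checkpoint ships a chat template.

    # Illustrative only: probe a tokenizer for chat-template support the same
    # way the diff does for self.generator.tokenizer.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")  # model id taken from the tests below
    use_chat = hasattr(tok, "chat_template") and tok.chat_template is not None
    print(use_chat)  # True only if this checkpoint ships a chat template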
@@ -98,6 +105,9 @@ def _load_client(self):
     def _clear_client(self):
         self.generator = None

+    def _format_chat_prompt(self, prompt: str) -> List[dict]:
+        return [{"role": "user", "content": prompt}]
+
     def _call_model(
         self, prompt: str, generations_this_call: int = 1
     ) -> List[Union[str, None]]:
@@ -106,13 +116,16 @@ def _call_model(
             warnings.simplefilter("ignore", category=UserWarning)
             try:
                 with torch.no_grad():
-                    # workaround for pipeline to truncate the input
-                    encoded_prompt = self.generator.tokenizer(prompt, truncation=True)
-                    truncated_prompt = self.generator.tokenizer.decode(
-                        encoded_prompt["input_ids"], skip_special_tokens=True
-                    )
+                    # according to docs https://huggingface.co/docs/transformers/main/en/chat_templating
+                    # chat template should be automatically utilized if the pipeline tokenizer has support
+                    # and a properly formatted list[dict] is supplied
+                    if self.use_chat:
+                        formatted_prompt = self._format_chat_prompt(prompt)
+                    else:
+                        formatted_prompt = prompt

Review comment on lines +123 to +126: Note for future Erick and Jeffrey: we will want to merge …

                    raw_output = self.generator(
-                        truncated_prompt,
+                        formatted_prompt,
                        pad_token_id=self.generator.tokenizer.eos_token_id,
                        max_new_tokens=self.max_tokens,
                        num_return_sequences=generations_this_call,
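As a gloss on the code comment in the hunk above, here is a hedged sketch (model name and generation settings are illustrative, not from the PR) of how a text-generation pipeline consumes a chat-format prompt, and why the reply is pulled from the last message of "generated_text" in the next hunk:

    # Illustrative sketch; "some-chat-model" is a placeholder checkpoint id.
    from transformers import pipeline

    generator = pipeline("text-generation", model="some-chat-model")
    messages = [{"role": "user", "content": "Hello world!"}]  # shape produced by _format_chat_prompt
    raw_output = generator(messages, max_new_tokens=32, num_return_sequences=1)
    # With chat input, each "generated_text" is the whole conversation as a list
    # of role/content dicts, so the assistant reply is the last entry's "content":
    reply = raw_output[0]["generated_text"][-1]["content"].strip()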
@@ -127,10 +140,15 @@ def _call_model(
                    i["generated_text"] for i in raw_output
                ]  # generator returns 10 outputs by default in __init__

+        if self.use_chat:
+            text_outputs = [_o[-1]["content"].strip() for _o in outputs]
+        else:
+            text_outputs = outputs
+
Review comment on lines +144 to +148: Comment not for this PR: there are multiple ways of representing conversations in garak, and attempt should be canonical for conversation history; I'm starting to consider patterns where attempt holds the conversation history still, but where this can be read & written using other interfaces, like an OpenAI API messages dict list, or this HF style format. But maybe the work is so simple that this current pattern works fine.

        if not self.deprefix_prompt:
-            return outputs
+            return text_outputs
        else:
-            return [re.sub("^" + re.escape(prompt), "", _o) for _o in outputs]
+            return [re.sub("^" + re.escape(prompt), "", _o) for _o in text_outputs]


class OptimumPipeline(Pipeline, HFCompatible):
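To make the comparison in that review comment concrete: the "OpenAI API messages dict list" and the HF chat-template input are the same shape, so one canonical attempt history could feed either interface. A purely hypothetical helper (none of these names exist in garak):

    # Hypothetical sketch only; garak's attempt API is deliberately not used here.
    from typing import List, Tuple

    def turns_to_messages(turns: List[Tuple[str, str]]) -> List[dict]:
        # (role, text) pairs -> the role/content dicts accepted by both the
        # OpenAI chat API and transformers' apply_chat_template
        return [{"role": role, "content": text} for role, text in turns]

    print(turns_to_messages([("user", "Hello world!")]))
    # [{'role': 'user', 'content': 'Hello world!'}]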
@@ -468,6 +486,12 @@ def _load_client(self):
                self.name, padding_side="left"
            )

+        # test tokenizer for `apply_chat_template` support
+        self.use_chat = (
+            hasattr(self.tokenizer, "chat_template")
+            and self.tokenizer.chat_template is not None
+        )
+
Review comment: This logic is in a few places, is it worth factoring up to …

Reply: the object format is slightly different

    def _supports_chat(self, tokenizer) -> bool:
        return hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None

I would like to defer this until I see a third usage or usage outside this module.

        self.generation_config = transformers.GenerationConfig.from_pretrained(
            self.name
        )
@@ -492,14 +516,27 @@ def _call_model(
        if self.top_k is not None:
            self.generation_config.top_k = self.top_k

-        text_output = []
+        raw_text_output = []
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            with torch.no_grad():
+                if self.use_chat:
+                    formatted_prompt = self.tokenizer.apply_chat_template(
+                        self._format_chat_prompt(prompt),
+                        tokenize=False,
+                        add_generation_prompt=True,
+                    )
+                else:
+                    formatted_prompt = prompt
+
                inputs = self.tokenizer(
-                    prompt, truncation=True, return_tensors="pt"
+                    formatted_prompt, truncation=True, return_tensors="pt"
                ).to(self.device)

+                prefix_prompt = self.tokenizer.decode(
+                    inputs["input_ids"][0], skip_special_tokens=True
+                )
+
                try:
                    outputs = self.model.generate(
                        **inputs, generation_config=self.generation_config
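A sketch of what the chat branch above produces, assuming an illustrative chat-capable tokenizer; it shows why the decoded `prefix_prompt` (rather than the raw prompt) is what later gets stripped from the output:

    # Illustrative only; "some-chat-model" is a placeholder checkpoint id.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("some-chat-model")
    templated = tok.apply_chat_template(
        [{"role": "user", "content": "Hello world!"}],
        tokenize=False,
        add_generation_prompt=True,  # append the assistant-turn header
    )
    inputs = tok(templated, truncation=True, return_tensors="pt")
    # The prefix that will appear at the start of the decoded generation is the
    # templated, possibly truncated string with special tokens removed:
    prefix_prompt = tok.decode(inputs["input_ids"][0], skip_special_tokens=True)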
@@ -512,14 +549,22 @@ def _call_model(
                        return returnval
                    else:
                        raise e
-            text_output = self.tokenizer.batch_decode(
+            raw_text_output = self.tokenizer.batch_decode(
                outputs, skip_special_tokens=True, device=self.device
            )

+            if self.use_chat:
+                text_output = [
+                    re.sub("^" + re.escape(prefix_prompt), "", i).strip()
+                    for i in raw_text_output
+                ]
+            else:
+                text_output = raw_text_output
+
Review comment on lines +558 to +564: Part of me (albeit a part with limited convictions) feels like there HAS to be a better way to handle this. HF seems to be managing it internally with their …

Reply: idk man, the format uses … Another option could be to do … My experience also is that things like …

        if not self.deprefix_prompt:
            return text_output
        else:
-            return [re.sub("^" + re.escape(prompt), "", i) for i in text_output]
+            return [re.sub("^" + re.escape(prefix_prompt), "", i) for i in text_output]


class LLaVA(Generator, HFCompatible):
File: tests/generators/test_huggingface.py
@@ -50,6 +50,18 @@ def test_pipeline(hf_generator_config):
        assert isinstance(item, str)


+def test_pipeline_chat(mocker, hf_generator_config):
+    # uses a ~350M model with chat support
+    g = garak.generators.huggingface.Pipeline(
+        "microsoft/DialoGPT-small", config_root=hf_generator_config
+    )
+    mock_format = mocker.patch.object(
+        g, "_format_chat_prompt", wraps=g._format_chat_prompt
+    )
+    g.generate("Hello world!")
+    mock_format.assert_called_once()
+
+
def test_inference(mocker, hf_mock_response, hf_generator_config):
    model_name = "gpt2"
    mock_request = mocker.patch.object(
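For readers unfamiliar with the `wraps=` idiom used in `test_pipeline_chat`, a minimal pytest-mock sketch (toy class, not garak code) showing that the patched method still runs while its calls are recorded:

    # Toy example of mocker.patch.object(..., wraps=...) acting as a spy.
    class Greeter:
        def greet(self, name):
            return f"hello {name}"

    def test_wraps_spy(mocker):
        g = Greeter()
        spy = mocker.patch.object(g, "greet", wraps=g.greet)
        assert g.greet("world") == "hello world"  # real implementation still executes
        spy.assert_called_once_with("world")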
@@ -121,6 +133,18 @@ def test_model(hf_generator_config):
        assert item is None  # gpt2 is known raise exception returning `None`


+def test_model_chat(mocker, hf_generator_config):
+    # uses a ~350M model with chat support
+    g = garak.generators.huggingface.Model(
+        "microsoft/DialoGPT-small", config_root=hf_generator_config
+    )
+    mock_format = mocker.patch.object(
+        g, "_format_chat_prompt", wraps=g._format_chat_prompt
+    )
+    g.generate("Hello world!")
+    mock_format.assert_called_once()
+
+
Review comment: This is a great test for where we are now -- would be nice to have something where we could specifically test the deprefixing, but there is some unavoidable stochasticity here that I'm not sure how to account for.

Reply: Fair point; in the vein of wanting less reliance on actual models, I think adding a test in the future that mocks a model response completely by offering a known test result should be a reasonable way to approach adding a …

Reply: Nice find re: the model. I wonder if we could use something like the …

def test_select_hf_device():
    from garak.generators.huggingface import HFCompatible
    import torch
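A rough sketch of the fully-mocked test floated in that thread, using a toy generator rather than garak internals (all names here are hypothetical): stub the model call with a known completion so the deprefix step can be asserted deterministically.

    # Hypothetical illustration; not garak's Model class.
    import re

    class ToyGenerator:
        def _model_call(self, prompt: str) -> str:
            raise NotImplementedError  # normally the stochastic model call

        def generate(self, prompt: str) -> str:
            output = self._model_call(prompt)
            return re.sub("^" + re.escape(prompt), "", output).strip()

    def test_deprefix_with_mocked_model(mocker):
        g = ToyGenerator()
        mocker.patch.object(
            g, "_model_call", return_value="Hello world! Nice to meet you."
        )
        assert g.generate("Hello world!") == "Nice to meet you."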
Review comment: Sometimes I wish this were just the default for HF, but c'est la vie. We may want to consider whether `truncation = True` requires `max_len` to be set -- I've had HF yell at me for not specifying both, but it may be a corner case that I encountered.

Reply: HF is kinda yell-y, and over the course of garak dev, HF has varied what it yells about. It can also be the case that some models require certain param combos to operate, while others will find the same param combo utterly intolerable. This has led to a style where one tries to do the right thing in garak, and tries to listen to HF warnings a little less.
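On the `truncation = True` / `max_len` question, a hedged example of being explicit about both, which sidesteps the ambiguity the comment describes (the warning behaviour varies across transformers versions):

    # gpt2 is used only because it already appears in the tests; any checkpoint works.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    enc = tok(
        "a very long prompt " * 500,
        truncation=True,
        max_length=tok.model_max_length,  # be explicit rather than relying on the default
    )
    print(len(enc["input_ids"]))  # <= tok.model_max_length (1024 for gpt2)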