Feature: use hf chat support #1047

Merged: 6 commits, Dec 19, 2024. Changes from 4 commits.
69 changes: 57 additions & 12 deletions garak/generators/huggingface.py
@@ -79,6 +79,9 @@ def _load_client(self):
set_seed(_config.run.seed)

pipeline_kwargs = self._gather_hf_params(hf_constructor=pipeline)
pipeline_kwargs["truncation"] = (
True # this is forced to maintain existing pipeline expectations
)

Reviewer comment (Collaborator): Sometimes I wish this were just the default for HF, but c'est la vie. We may want to consider whether truncation=True requires max_length to be set -- I've had HF yell at me for not specifying both, but it may be a corner case that I encountered.

Reviewer comment (Collaborator): HF is kinda yell-y, and over the course of garak dev, HF has varied what it yells about. It can also be the case that some models require certain param combos to operate, while others will find the same param combo utterly intolerable. This has led to a style where one tries to do the right thing in garak, and tries to listen to HF warnings a little less.
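A minimal sketch of the combination the first comment above alludes to: pairing truncation=True with an explicit max_length so the cut-off is unambiguous. The model and prompt here are placeholders, not anything from this PR.

```python
from transformers import AutoTokenizer

# Illustrative only: truncation paired with an explicit max_length bound.
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
encoded = tokenizer(
    "a long prompt " * 1000,
    truncation=True,
    max_length=tokenizer.model_max_length,  # explicit cap alongside truncation=True
)
assert len(encoded["input_ids"]) <= tokenizer.model_max_length
```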
self.generator = pipeline("text-generation", **pipeline_kwargs)
if self.generator.tokenizer is None:
# account for possible model without a stored tokenizer
@@ -87,6 +90,10 @@ def _load_client(self):
self.generator.tokenizer = AutoTokenizer.from_pretrained(
pipeline_kwargs["model"]
)
self.use_chat = (
hasattr(self.generator.tokenizer, "chat_template")
and self.generator.tokenizer.chat_template is not None
)
if not hasattr(self, "deprefix_prompt"):
self.deprefix_prompt = self.name in models_to_deprefix
if _config.loaded:
@@ -98,6 +105,9 @@ def _load_client(self):
def _clear_client(self):
self.generator = None

def _format_chat_prompt(self, prompt: str) -> List[dict]:
return [{"role": "user", "content": prompt}]

def _call_model(
self, prompt: str, generations_this_call: int = 1
) -> List[Union[str, None]]:
@@ -106,13 +116,16 @@ def _call_model(
warnings.simplefilter("ignore", category=UserWarning)
try:
with torch.no_grad():
- # workaround for pipeline to truncate the input
- encoded_prompt = self.generator.tokenizer(prompt, truncation=True)
- truncated_prompt = self.generator.tokenizer.decode(
-     encoded_prompt["input_ids"], skip_special_tokens=True
- )
+ # according to docs https://huggingface.co/docs/transformers/main/en/chat_templating
+ # chat template should be automatically utilized if the pipeline tokenizer has support
+ # and a properly formatted list[dict] is supplied
+ if self.use_chat:
+     formatted_prompt = self._format_chat_prompt(prompt)
+ else:
+     formatted_prompt = prompt
Reviewer comment on lines +123 to +126 (Collaborator): Note for future Erick and Jeffrey: we will want to merge ConversationalPipeline into this at some point, given HF's deprecation of the "conversational" pipeline and the Conversation object. At that point, we'll also want to validate that prompt is a string.
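The docs link in the added comment above describes the behaviour the chat branch relies on: when a list of role/content dicts is passed to a text-generation pipeline whose tokenizer carries a chat template, the template is applied automatically and generated_text comes back as the full message list, which is why the chat path further down reads the reply via [-1]["content"]. A minimal sketch under that assumption; the model name is a placeholder, not part of this PR.

```python
from transformers import pipeline

# Placeholder chat-capable model; any model whose tokenizer ships a chat_template works.
generator = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-135M-Instruct")

messages = [{"role": "user", "content": "Hello world!"}]  # same shape as _format_chat_prompt
out = generator(messages, max_new_tokens=32)

# generated_text holds the whole conversation; the final entry is the assistant reply
print(out[0]["generated_text"][-1]["content"])
```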


raw_output = self.generator(
- truncated_prompt,
+ formatted_prompt,
pad_token_id=self.generator.tokenizer.eos_token_id,
max_new_tokens=self.max_tokens,
num_return_sequences=generations_this_call,
@@ -127,10 +140,15 @@
i["generated_text"] for i in raw_output
] # generator returns 10 outputs by default in __init__

if self.use_chat:
text_outputs = [_o[-1]["content"].strip() for _o in outputs]
else:
text_outputs = outputs

Reviewer comment from @leondz (Dec 19, 2024) on lines +144 to +148: Comment not for this PR: there are multiple ways of representing conversations in garak, and attempt should be canonical for conversation history. I'm starting to consider patterns where attempt still holds the conversation history, but where it can be read and written using other interfaces, like an OpenAI API messages dict list or this HF-style format. But maybe the work is so simple that this current pattern works fine.
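Sketching the pattern that comment describes, with purely hypothetical names (none of these exist in garak): a canonical (role, text) turn history that can be read and written as an OpenAI/HF-style messages list.

```python
from typing import List, Tuple

def history_to_messages(history: List[Tuple[str, str]]) -> List[dict]:
    # e.g. [("user", "hi"), ("assistant", "hello")] -> [{"role": "user", "content": "hi"}, ...]
    return [{"role": role, "content": text} for role, text in history]

def messages_to_history(messages: List[dict]) -> List[Tuple[str, str]]:
    return [(m["role"], m["content"]) for m in messages]

assert messages_to_history(history_to_messages([("user", "hi")])) == [("user", "hi")]
```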

if not self.deprefix_prompt:
- return outputs
+ return text_outputs
else:
- return [re.sub("^" + re.escape(prompt), "", _o) for _o in outputs]
+ return [re.sub("^" + re.escape(prompt), "", _o) for _o in text_outputs]


class OptimumPipeline(Pipeline, HFCompatible):
@@ -468,6 +486,12 @@ def _load_client(self):
self.name, padding_side="left"
)

# test tokenizer for `apply_chat_template` support
self.use_chat = (
hasattr(self.tokenizer, "chat_template")
and self.tokenizer.chat_template is not None
)

Reviewer comment (Collaborator): This logic is in a few places; is it worth factoring up to HFCompatible?

Author reply (Collaborator): The object format is slightly different (self.tokenizer vs self.generator.tokenizer). There is a possible abstraction here, _supports_chat(tokenizer), though I'm not sold on the value yet:

def _supports_chat(self, tokenizer) -> bool:
    return hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None

I would like to defer this until I see a third usage, or usage outside this module.
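For reference, the deferred refactor would look roughly like the sketch below; the class here is a stand-in for the real HFCompatible mixin and is shown only to illustrate the shared helper and its two call sites.

```python
class HFCompatible:  # stand-in, not the real garak mixin
    def _supports_chat(self, tokenizer) -> bool:
        # a tokenizer advertises chat support by carrying a non-None chat_template
        return (
            hasattr(tokenizer, "chat_template")
            and tokenizer.chat_template is not None
        )

# Pipeline._load_client would call: self.use_chat = self._supports_chat(self.generator.tokenizer)
# Model._load_client would call:    self.use_chat = self._supports_chat(self.tokenizer)
```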

self.generation_config = transformers.GenerationConfig.from_pretrained(
self.name
)
@@ -492,14 +516,27 @@ def _call_model(
if self.top_k is not None:
self.generation_config.top_k = self.top_k

- text_output = []
+ raw_text_output = []
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
with torch.no_grad():
if self.use_chat:
formatted_prompt = self.tokenizer.apply_chat_template(
self._format_chat_prompt(prompt),
tokenize=False,
add_generation_prompt=True,
)
else:
formatted_prompt = prompt

inputs = self.tokenizer(
- prompt, truncation=True, return_tensors="pt"
+ formatted_prompt, truncation=True, return_tensors="pt"
).to(self.device)

prefix_prompt = self.tokenizer.decode(
inputs["input_ids"][0], skip_special_tokens=True
)

try:
outputs = self.model.generate(
**inputs, generation_config=self.generation_config
@@ -512,14 +549,22 @@
return returnval
else:
raise e
- text_output = self.tokenizer.batch_decode(
+ raw_text_output = self.tokenizer.batch_decode(
outputs, skip_special_tokens=True, device=self.device
)

if self.use_chat:
text_output = [
re.sub("^" + re.escape(prefix_prompt), "", i).strip()
for i in raw_text_output
]
else:
text_output = raw_text_output
Reviewer comment on lines +558 to +564 (Collaborator): Part of me (albeit a part with limited conviction) feels like there HAS to be a better way to handle this. HF seems to be managing it internally with their pipelines, but this seems fine for the time being.

Reviewer comment (Collaborator): idk man, the format uses \n as a special char but doesn't seem to do much escaping, so many bets are off.

Another option could be to do \n-based escaping and parsing of messages line by line, assuming [A-Z][a-z]+: is a metadata line.

My experience also is that things like prefix_prompt get mangled in other code sometimes. I don't have great suggestions for how to check this well, other than:

  • continue doing qualitative reviews of prompt:output pairs for suspicious items;
  • migrate to a data format with clearer separation between data and metadata;
  • badger HF for a reference implementation;
  • give notice to users that HF Chat is "janky".


if not self.deprefix_prompt:
return text_output
else:
- return [re.sub("^" + re.escape(prompt), "", i) for i in text_output]
+ return [re.sub("^" + re.escape(prefix_prompt), "", i) for i in text_output]
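Pulling the discussion above together, here is the Model-path flow as one self-contained sketch (model name and token budget are illustrative, not from this PR): render the chat template to text, tokenize with truncation, decode the input ids to get the exact prefix, generate, and strip that prefix from the decoded output.

```python
import re

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "HuggingFaceTB/SmolLM2-135M-Instruct"  # placeholder chat-capable model
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(model_name)

formatted_prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello world!"}],
    tokenize=False,
    add_generation_prompt=True,
)
inputs = tokenizer(formatted_prompt, truncation=True, return_tensors="pt")

# decode exactly the tokens fed to the model, so the prefix matches what comes back out
prefix_prompt = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)

with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=32)

decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
replies = [re.sub("^" + re.escape(prefix_prompt), "", d).strip() for d in decoded]
print(replies)
```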


class LLaVA(Generator, HFCompatible):
24 changes: 24 additions & 0 deletions tests/generators/test_huggingface.py
@@ -50,6 +50,18 @@ def test_pipeline(hf_generator_config):
assert isinstance(item, str)


def test_pipeline_chat(mocker, hf_generator_config):
# uses a ~350M model with chat support
g = garak.generators.huggingface.Pipeline(
"microsoft/DialoGPT-small", config_root=hf_generator_config
)
mock_format = mocker.patch.object(
g, "_format_chat_prompt", wraps=g._format_chat_prompt
)
g.generate("Hello world!")
mock_format.assert_called_once()


def test_inference(mocker, hf_mock_response, hf_generator_config):
model_name = "gpt2"
mock_request = mocker.patch.object(
@@ -121,6 +133,18 @@ def test_model(hf_generator_config):
assert item is None # gpt2 is known raise exception returning `None`


def test_model_chat(mocker, hf_generator_config):
# uses a ~350M model with chat support
g = garak.generators.huggingface.Model(
"microsoft/DialoGPT-small", config_root=hf_generator_config
)
mock_format = mocker.patch.object(
g, "_format_chat_prompt", wraps=g._format_chat_prompt
)
g.generate("Hello world!")
mock_format.assert_called_once()
Reviewer comment (Collaborator): This is a great test for where we are now -- it would be nice to have something where we could specifically test the deprefixing, but there is some unavoidable stochasticity here that I'm not sure how to account for.

Author reply (Collaborator): Fair point. In the vein of wanting less reliance on actual models, I think a future test that completely mocks the model response with a known result would be a reasonable way to add unit-test validation of the deprefix function.

Reviewer comment (Collaborator): Nice find re: the model. I wonder if we could use something like the lipsum module (which we already depend upon) to mock up a test HF generator that could be used in many of our tests. Maybe for Jan?
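A sketch of the mocked-response test discussed above; the test name, the canned completion, and the direct _call_model invocation are illustrative assumptions rather than anything in this PR.

```python
def test_pipeline_deprefix_mocked(mocker, hf_generator_config):
    # hypothetical future test: swap the generation call for a canned response so
    # deprefixing can be asserted deterministically
    g = garak.generators.huggingface.Pipeline(
        "gpt2", config_root=hf_generator_config
    )
    prompt = "Hello world!"
    mocker.patch.object(
        g,
        "generator",
        return_value=[{"generated_text": prompt + " General Kenobi."}],
    )
    g.deprefix_prompt = True
    assert g._call_model(prompt) == [" General Kenobi."]
```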



def test_select_hf_device():
from garak.generators.huggingface import HFCompatible
import torch