Generate using exported model and enable gemma2-2b in ExecuTorch (huggingface#33707)

* Generate using exported model and enable gemma2-2b in ExecuTorch

* [run_slow] gemma, gemma2

* truncate expected output message

* Bump required torch version to support gemma2 export

* [run_slow] gemma, gemma2

---------

Co-authored-by: Guang Yang <[email protected]>
guangy10 and Guang Yang authored Oct 11, 2024
1 parent 70b07d9 commit 7d97cca
Showing 3 changed files with 168 additions and 0 deletions.
50 changes: 50 additions & 0 deletions src/transformers/integrations/executorch.py
@@ -114,6 +114,56 @@ def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
        )
        return outs.logits

    @staticmethod
    def generate(
        exported_program: torch.export.ExportedProgram, prompt_token_ids: torch.Tensor, max_new_tokens: int
    ) -> torch.Tensor:
        """
        Generate a sequence of tokens using an exported program.

        This utility function is designed to test exported models by simulating the generation process.
        It processes the input prompt tokens sequentially (no parallel prefill). It is not intended to
        replace the original `generate` method; support for leveraging the original `generate` may be added later.

        Args:
            exported_program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
            prompt_token_ids (`torch.Tensor`): Tensor representing the input prompt token IDs.
            max_new_tokens (`int`): Maximum number of new tokens to generate. Note that the total generation
                length is limited by both `max_new_tokens` and the model's cache size.

        Returns:
            torch.Tensor: A tensor containing the generated sequence of token IDs, including the original prompt tokens.
        """
        prompt_token_len = prompt_token_ids.shape[-1]
        max_generation_length = prompt_token_len + max_new_tokens
        # Clamp the generation length to the capacity of the exported static cache.
        for buffer_name, buffer in exported_program.named_buffers():
            if buffer_name.startswith("static_cache.key_cache"):
                max_cache_len = buffer.shape[2]
                max_generation_length = min(max_generation_length, max_cache_len)
                break

        # Prefill: feed the prompt tokens one at a time.
        response_tokens = []
        for input_pos in range(min(max_generation_length, prompt_token_len)):
            result = exported_program.module().forward(
                input_ids=prompt_token_ids[:, input_pos : input_pos + 1],
                cache_position=torch.tensor([input_pos], dtype=torch.long),
            )
            response_tokens.append(prompt_token_ids[0][input_pos].item())

        current_token = torch.argmax(result[:, -1, :], dim=-1).item()
        response_tokens.append(current_token)

        # Decode: greedily pick one token at a time until the length budget is exhausted.
        while len(response_tokens) < max_generation_length:
            result = exported_program.module().forward(
                input_ids=torch.tensor([[current_token]], dtype=torch.long),
                cache_position=torch.tensor([len(response_tokens)], dtype=torch.long),
            )
            current_token = torch.argmax(result[:, -1, :], dim=-1).item()
            response_tokens.append(current_token)

        return torch.tensor([response_tokens], dtype=torch.long)


def convert_and_export_with_cache(
    model: PreTrainedModel,
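For reference, a minimal usage sketch of the new helper (it mirrors the tests added below; the checkpoint name and length budget here are illustrative):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.configuration_utils import GenerationConfig
from transformers.integrations.executorch import (
    TorchExportableModuleWithStaticCache,
    convert_and_export_with_cache,
)

# Illustrative checkpoint and budget; the tests below use google/gemma-2b and google/gemma-2-2b.
model_id = "google/gemma-2b"
max_length = 64

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    generation_config=GenerationConfig(
        use_cache=True,
        cache_implementation="static",
        max_length=max_length,
        cache_config={"batch_size": 1, "max_cache_len": max_length},
    ),
)

prompt_token_ids = tokenizer(["Hello I am doing"], return_tensors="pt")["input_ids"]

# Export once with a static cache, then generate directly from the exported program.
exported_program = convert_and_export_with_cache(model)
generated_ids = TorchExportableModuleWithStaticCache.generate(
    exported_program=exported_program,
    prompt_token_ids=prompt_token_ids,
    max_new_tokens=max_length - prompt_token_ids.shape[-1],
)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])

Because the helper prefills the prompt one token at a time and then decodes greedily, its output should match greedy eager generation with a static cache, which is what the gemma test below asserts.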
62 changes: 62 additions & 0 deletions tests/models/gemma/test_modeling_gemma.py
@@ -21,6 +21,7 @@
from packaging import version

from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import (
    is_flaky,
    require_bitsandbytes,
@@ -841,6 +842,67 @@ def test_compile_static_cache(self):
        static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)

    @slow
    @require_read_token
    def test_export_static_cache(self):
        if version.parse(torch.__version__) < version.parse("2.3.0"):
            self.skipTest(reason="This test requires torch >= 2.3 to run.")

        from transformers.integrations.executorch import (
            TorchExportableModuleWithStaticCache,
            convert_and_export_with_cache,
        )

        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
        EXPECTED_TEXT_COMPLETION = [
            "Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have looked on the internet and I have found",
        ]
        max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
            "input_ids"
        ].shape[-1]

        # Load model
        device = "cpu"
        dtype = torch.bfloat16
        cache_implementation = "static"
        attn_implementation = "sdpa"
        batch_size = 1
        model = GemmaForCausalLM.from_pretrained(
            "google/gemma-2b",
            device_map=device,
            torch_dtype=dtype,
            attn_implementation=attn_implementation,
            generation_config=GenerationConfig(
                use_cache=True,
                cache_implementation=cache_implementation,
                max_length=max_generation_length,
                cache_config={
                    "batch_size": batch_size,
                    "max_cache_len": max_generation_length,
                },
            ),
        )

        prompts = ["Hello I am doing"]
        prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
        prompt_token_ids = prompt_tokens["input_ids"]
        max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]

        # Static Cache + eager
        eager_generated_ids = model.generate(
            **prompt_tokens, max_new_tokens=max_new_tokens, do_sample=False, cache_implementation=cache_implementation
        )
        eager_generated_text = tokenizer.batch_decode(eager_generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text)

        # Static Cache + export
        exported_program = convert_and_export_with_cache(model)
        ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
            exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
        )
        ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)

    def test_model_2b_bf16_dola(self):
        model_id = "google/gemma-2b"
        # ground truth text generated with dola_layers="low", repetition_penalty=1.2
56 changes: 56 additions & 0 deletions tests/models/gemma2/test_modeling_gemma2.py
@@ -16,10 +16,12 @@

import unittest

from packaging import version
from parameterized import parameterized
from pytest import mark

from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, HybridCache, is_torch_available, pipeline
from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import (
    require_flash_attn,
    require_read_token,
@@ -306,3 +308,57 @@ def test_model_9b_flash_attn(self):
        output_text = tokenizer.batch_decode(output, skip_special_tokens=False)

        self.assertEqual(output_text, EXPECTED_TEXTS)

    @slow
    @require_read_token
    def test_export_static_cache(self):
        if version.parse(torch.__version__) < version.parse("2.5.0"):
            self.skipTest(reason="This test requires torch >= 2.5 to run.")

        from transformers.integrations.executorch import (
            TorchExportableModuleWithStaticCache,
            convert_and_export_with_cache,
        )

        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b", pad_token="</s>", padding_side="right")
        EXPECTED_TEXT_COMPLETION = [
            "Hello I am doing a project for my school and I need to know how to make a program that will take a number",
        ]
        max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
            "input_ids"
        ].shape[-1]

        # Load model
        device = "cpu"
        dtype = torch.bfloat16
        cache_implementation = "static"
        attn_implementation = "sdpa"
        batch_size = 1
        model = AutoModelForCausalLM.from_pretrained(
            "google/gemma-2-2b",
            device_map=device,
            torch_dtype=dtype,
            attn_implementation=attn_implementation,
            generation_config=GenerationConfig(
                use_cache=True,
                cache_implementation=cache_implementation,
                max_length=max_generation_length,
                cache_config={
                    "batch_size": batch_size,
                    "max_cache_len": max_generation_length,
                },
            ),
        )

        prompts = ["Hello I am doing"]
        prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
        prompt_token_ids = prompt_tokens["input_ids"]
        max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]

        # Static Cache + export
        exported_program = convert_and_export_with_cache(model)
        ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
            exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
        )
        ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
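
Since the helper only needs an `ExportedProgram`, the export can in principle be saved and reloaded before generation, for example with `torch.export.save`/`torch.export.load` (assuming the installed torch version provides them; the file name here is illustrative). A sketch, continuing from the names defined in the test above:

import torch

from transformers.integrations.executorch import TorchExportableModuleWithStaticCache

# Persist the exported program (torch.export uses the .pt2 container format), then reload it.
torch.export.save(exported_program, "gemma2_static_cache.pt2")
reloaded_program = torch.export.load("gemma2_static_cache.pt2")

# The same helper works on the reloaded program, since it only reads the static-cache
# buffers and calls the exported module's forward.
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
    exported_program=reloaded_program,
    prompt_token_ids=prompt_token_ids,
    max_new_tokens=max_new_tokens,
)
print(tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)[0])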
