From 7d97cca8dde8b1d5a21a4df56c340cfe3cdd9fb4 Mon Sep 17 00:00:00 2001
From: Guang Yang <42389959+guangy10@users.noreply.github.com>
Date: Fri, 11 Oct 2024 01:16:31 -0700
Subject: [PATCH] Generate using exported model and enable gemma2-2b in
 ExecuTorch (#33707)

* Generate using exported model and enable gemma2-2b in ExecuTorch

* [run_slow] gemma, gemma2

* truncate expected output message

* Bump required torch version to support gemma2 export

* [run_slow] gemma, gemma2

---------

Co-authored-by: Guang Yang
---
 src/transformers/integrations/executorch.py | 50 +++++++++++++++++
 tests/models/gemma/test_modeling_gemma.py   | 62 +++++++++++++++++++++
 tests/models/gemma2/test_modeling_gemma2.py | 56 +++++++++++++++++++
 3 files changed, 168 insertions(+)

diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py
index 4adfcc39a4b1fa..c0adff386f6312 100644
--- a/src/transformers/integrations/executorch.py
+++ b/src/transformers/integrations/executorch.py
@@ -114,6 +114,56 @@ def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
         )
         return outs.logits
 
+    @staticmethod
+    def generate(
+        exported_program: torch.export.ExportedProgram, prompt_token_ids: torch.Tensor, max_new_tokens: int
+    ) -> torch.Tensor:
+        """
+        Generate a sequence of tokens using an exported program.
+
+        This utility function is designed to test exported models by simulating the generation process.
+        It processes the input prompt tokens sequentially (no parallel prefill).
+        This generate function is not intended to replace the original `generate` method; support for
+        leveraging the original `generate` is planned.
+
+        Args:
+            exported_program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
+            prompt_token_ids (`torch.Tensor`): Tensor representing the input prompt token IDs.
+            max_new_tokens (`int`): Maximum number of new tokens to generate. Note that the total generation
+                length is limited by both `max_new_tokens` and the model's cache size.
+
+        Returns:
+            torch.Tensor: A tensor containing the generated sequence of token IDs, including the original prompt tokens.
+ """ + prompt_token_len = prompt_token_ids.shape[-1] + max_generation_length = prompt_token_len + max_new_tokens + for buffer_name, buffer in exported_program.named_buffers(): + if buffer_name.startswith("static_cache.key_cache"): + max_cache_len = buffer.shape[2] + max_generation_length = min(max_generation_length, max_cache_len) + break + + response_tokens = [] + for input_pos in range(min(max_generation_length, prompt_token_len)): + result = exported_program.module().forward( + input_ids=prompt_token_ids[:, input_pos : input_pos + 1], + cache_position=torch.tensor([input_pos], dtype=torch.long), + ) + response_tokens.append(prompt_token_ids[0][input_pos].item()) + + current_token = torch.argmax(result[:, -1, :], dim=-1).item() + response_tokens.append(current_token) + + while len(response_tokens) < max_generation_length: + result = exported_program.module().forward( + input_ids=torch.tensor([[current_token]], dtype=torch.long), + cache_position=torch.tensor([len(response_tokens)], dtype=torch.long), + ) + current_token = torch.argmax(result[:, -1, :], dim=-1).item() + response_tokens.append(current_token) + + return torch.tensor([response_tokens], dtype=torch.long) + def convert_and_export_with_cache( model: PreTrainedModel, diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index 1a6efe5735275e..a888bdcd3bc7be 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -21,6 +21,7 @@ from packaging import version from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available +from transformers.generation.configuration_utils import GenerationConfig from transformers.testing_utils import ( is_flaky, require_bitsandbytes, @@ -841,6 +842,67 @@ def test_compile_static_cache(self): static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) + @slow + @require_read_token + def test_export_static_cache(self): + if version.parse(torch.__version__) < version.parse("2.3.0"): + self.skipTest(reason="This test requires torch >= 2.3 to run.") + + from transformers.integrations.executorch import ( + TorchExportableModuleWithStaticCache, + convert_and_export_with_cache, + ) + + tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", pad_token="", padding_side="right") + EXPECTED_TEXT_COMPLETION = [ + "Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. 
I have looked on the internet and I have found", + ] + max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[ + "input_ids" + ].shape[-1] + + # Load model + device = "cpu" + dtype = torch.bfloat16 + cache_implementation = "static" + attn_implementation = "sdpa" + batch_size = 1 + model = GemmaForCausalLM.from_pretrained( + "google/gemma-2b", + device_map=device, + torch_dtype=dtype, + attn_implementation=attn_implementation, + generation_config=GenerationConfig( + use_cache=True, + cache_implementation=cache_implementation, + max_length=max_generation_length, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_generation_length, + }, + ), + ) + + prompts = ["Hello I am doing"] + prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + prompt_token_ids = prompt_tokens["input_ids"] + max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] + + # Static Cache + eager + eager_generated_ids = model.generate( + **prompt_tokens, max_new_tokens=max_new_tokens, do_sample=False, cache_implementation=cache_implementation + ) + eager_generated_text = tokenizer.batch_decode(eager_generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text) + + # Static Cache + export + exported_program = convert_and_export_with_cache(model) + ep_generated_ids = TorchExportableModuleWithStaticCache.generate( + exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens + ) + ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text) + def test_model_2b_bf16_dola(self): model_id = "google/gemma-2b" # ground truth text generated with dola_layers="low", repetition_penalty=1.2 diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 4e7b3553460f89..8f9a918dca0082 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -16,10 +16,12 @@ import unittest +from packaging import version from parameterized import parameterized from pytest import mark from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, HybridCache, is_torch_available, pipeline +from transformers.generation.configuration_utils import GenerationConfig from transformers.testing_utils import ( require_flash_attn, require_read_token, @@ -306,3 +308,57 @@ def test_model_9b_flash_attn(self): output_text = tokenizer.batch_decode(output, skip_special_tokens=False) self.assertEqual(output_text, EXPECTED_TEXTS) + + @slow + @require_read_token + def test_export_static_cache(self): + if version.parse(torch.__version__) < version.parse("2.5.0"): + self.skipTest(reason="This test requires torch >= 2.5 to run.") + + from transformers.integrations.executorch import ( + TorchExportableModuleWithStaticCache, + convert_and_export_with_cache, + ) + + tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b", pad_token="", padding_side="right") + EXPECTED_TEXT_COMPLETION = [ + "Hello I am doing a project for my school and I need to know how to make a program that will take a number", + ] + max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[ + "input_ids" + ].shape[-1] + + # Load model + device = "cpu" + dtype = torch.bfloat16 + cache_implementation = "static" + attn_implementation = "sdpa" + batch_size = 1 + model = 
AutoModelForCausalLM.from_pretrained( + "google/gemma-2-2b", + device_map=device, + torch_dtype=dtype, + attn_implementation=attn_implementation, + generation_config=GenerationConfig( + use_cache=True, + cache_implementation=cache_implementation, + max_length=max_generation_length, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_generation_length, + }, + ), + ) + + prompts = ["Hello I am doing"] + prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + prompt_token_ids = prompt_tokens["input_ids"] + max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] + + # Static Cache + export + exported_program = convert_and_export_with_cache(model) + ep_generated_ids = TorchExportableModuleWithStaticCache.generate( + exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens + ) + ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
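
Usage sketch (editor's note, not part of the patch): a minimal end-to-end driver for the helper added above, mirroring the tests in the diff. The names convert_and_export_with_cache, TorchExportableModuleWithStaticCache.generate, and the GenerationConfig-based static-cache setup come from the patch itself; the checkpoint, prompt, and length budget are illustrative assumptions, and running it requires a transformers build that contains this change plus access to the gated google/gemma-2b weights.

    import torch

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.generation.configuration_utils import GenerationConfig
    from transformers.integrations.executorch import (
        TorchExportableModuleWithStaticCache,
        convert_and_export_with_cache,
    )

    model_id = "google/gemma-2b"  # illustrative; any causal LM supported by the static-cache export path
    max_len = 64                  # illustrative generation/cache budget

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
        generation_config=GenerationConfig(
            use_cache=True,
            cache_implementation="static",
            max_length=max_len,
            cache_config={"batch_size": 1, "max_cache_len": max_len},
        ),
    )

    # Export once with the static cache baked in, then decode greedily from the exported program.
    exported_program = convert_and_export_with_cache(model)
    prompt_ids = tokenizer(["Hello I am doing"], return_tensors="pt")["input_ids"]
    generated_ids = TorchExportableModuleWithStaticCache.generate(
        exported_program=exported_program,
        prompt_token_ids=prompt_ids,
        max_new_tokens=max_len - prompt_ids.shape[-1],
    )
    print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])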