From 5eebd09d41eb97dd5e9a18a183ac9e461748c5f9 Mon Sep 17 00:00:00 2001
From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
Date: Fri, 19 Jan 2024 11:09:24 -0800
Subject: [PATCH] Update LLaMA attention fusions (#19200)

### Description
This PR updates the LLaMA-2 attention fusions by adding the following.
- Loading the PyTorch model from Hugging Face with the `LlamaAttention` class before exporting
- Updating the attention mask pattern matching to support another case

This PR also fixes [this issue](https://github.com/microsoft/onnxruntime/issues/19040).

### Motivation and Context
Recent changes to Hugging Face's `transformers` library break the existing pattern matching. Since the attention fusions aim to change the graph from `LayerNorm Op --> Set of Attention Nodes --> LayerNorm Op` to `LayerNorm Op --> Attention Op --> LayerNorm Op` per layer, ultimately it does not matter what nodes comprise the `Set of Attention Nodes` because they will all be removed and replaced by the `Attention Op` in the end. Therefore, it does not matter whether the `LlamaAttention` class or a different attention class is used to load the PyTorch model before exporting because the expected graphs after the attention fusions will look identical no matter the attention class chosen.

By loading the PyTorch model with the `LlamaAttention` class instead of other attention classes (e.g. `LlamaFlashAttention2` or `LlamaSdpaAttention`) and then exporting it to ONNX, the existing pattern matching will continue to work.
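For reference, here is a minimal sketch of how an export script can force the eager `LlamaAttention` path before export. The model id, dtype, and variable names below are illustrative and not part of this change:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Illustrative checkpoint; any LLaMA-2 model id works the same way.
model_id = "meta-llama/Llama-2-7b-hf"

# Force the eager LlamaAttention implementation so the exported ONNX graph
# keeps the attention subgraph form that the fusion pattern matching expects.
config = AutoConfig.from_pretrained(model_id)
config.use_cache = True
config._attn_implementation = "eager"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=config,
    torch_dtype=torch.float32,
)
model.eval()
```

Exporting a model loaded this way keeps the existing fusions applicable, whereas `LlamaFlashAttention2` or `LlamaSdpaAttention` may trace to subgraphs that the current patterns do not match.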
---
 .../transformers/fusion_rotary_attention.py   | 10 ++++++
 .../tools/transformers/models/llama/README.md | 31 +++++--------------
 .../models/llama/convert_to_onnx.py           | 27 ++++++++++++++++
 .../transformers/models/llama/llama_torch.py  |  1 +
 .../models/llama/requirements.txt             |  2 +-
 5 files changed, 46 insertions(+), 25 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py
index de89b35366a23..618d3c2fab12c 100644
--- a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py
+++ b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py
@@ -539,6 +539,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
 
         # attn_mask_nodes_1, attn_mask_nodes_2 are for LLaMA-2 Microsoft's 3D attention mask
         # attn_mask_nodes_3, attn_mask_nodes_4 are for LLaMA-2 Hugging Face's 2D attention mask
+        # attn_mask_nodes_5, attn_mask_nodes_6 are for LLaMA-2 Microsoft's model for the DML EP
+        # attn_mask_nodes_7 is for LLaMA-2 Hugging Face's changes to the attention mask
         attn_mask, add_qk_str = "", ""
         attn_mask_nodes_1 = self.model.match_parent_path(
             add_qk,
@@ -570,6 +572,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
             ["Expand", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
             [1, 0, 2, 1, 0, 0, 0],
         )
+        attn_mask_nodes_7 = self.model.match_parent_path(
+            add_qk,
+            ["Where", "Cast", "Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
+            [1, 0, 0, 0, 0, 1, 0, 0, 0],
+        )
         if attn_mask_nodes_1 is not None:
             _, slice_mask_1, slice_mask_2 = attn_mask_nodes_1
             attn_mask = slice_mask_1.output[0]
@@ -588,6 +595,9 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
         elif attn_mask_nodes_6 is not None:
             # The mask has already been reshaped to (B,N,S,T)
             add_qk_str = attn_mask_nodes_6[0].output[0]
+        elif attn_mask_nodes_7 is not None:
+            # Reshape from (B,1,S,T) to (B,N,S,T)
+            add_qk_str = self.reshape_add_qk(attn_mask_nodes_7[0].output[0])
         else:
             logger.debug("fuse_rotary_attention: failed to match attention mask nodes")
             return
diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md
index e7bcc19635f40..f9552e02d74b9 100644
--- a/onnxruntime/python/tools/transformers/models/llama/README.md
+++ b/onnxruntime/python/tools/transformers/models/llama/README.md
@@ -42,23 +42,6 @@ $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama
 
 To make this option compatible with [Hugging Face's Optimum](https://github.com/huggingface/optimum), you will need to create `config.json` and `generation_config.json` for your model and store them in the same directory as your ONNX models. For example, you can find those JSON files for LLaMA-2 7B on Hugging Face [here](https://huggingface.co/meta-llama/Llama-2-7b-hf).
 
-As indicated in `requirements.txt`, you will also need to install Optimum from source. Once installed, you will need to modify `ORTModelForCausalLM.forward` in `optimum/optimum/onnxruntime/modeling_decoder.py` as follows:
-
-```
-# Before
-if self.use_cache:
-    if past_key_values is not None:
-        input_ids = input_ids[:, -1:]
-    # Flatten the past_key_values (no need to flatten for models using multi-query attn)
-
-
-# After
-if self.use_cache:
-    if past_key_values is not None:
-        input_ids = input_ids[:, -1:] if past_key_values[0][0].shape[2] != 0 else input_ids
-    # Flatten the past_key_values (no need to flatten for models using multi-query attn)
-```
-
 ### Option 2: from [Microsoft's custom export](https://github.com/microsoft/Llama-2-Onnx)
 
 Please follow the [README instructions](https://github.com/microsoft/Llama-2-Onnx#before-you-start) in the custom export of LLaMA-2.
@@ -254,7 +237,7 @@ Here are some examples of how you can benchmark LLaMA-2.
 
 1. PyTorch without `torch.compile`, FP32
 ```
-python3 -m models.llama.benchmark \
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type hf-pt-eager \
     --model-name meta-llama/Llama-2-7b-hf \
     --precision fp32 \
@@ -266,7 +249,7 @@ python3 -m models.llama.benchmark \
 
 2. PyTorch with `torch.compile`, FP16
 ```
-python3 -m models.llama.benchmark \
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type hf-pt-compile \
     --model-name meta-llama/Llama-2-7b-hf \
     --precision fp16 \
@@ -278,7 +261,7 @@ python3 -m models.llama.benchmark \
 
 3. Optimum + ONNX Runtime, FP32, export via Optimum or convert_to_onnx
 ```
-python3 -m models.llama.benchmark \
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type hf-ort \
     --hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \
     --model-name meta-llama/Llama-2-7b-hf \
@@ -291,7 +274,7 @@ python3 -m models.llama.benchmark \
 
 4. Optimum + ONNX Runtime, FP16, export via Optimum or convert_to_onnx
 ```
-python3 -m models.llama.benchmark \
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type hf-ort \
     --hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \
     --model-name meta-llama/Llama-2-7b-hf \
@@ -304,7 +287,7 @@ python3 -m models.llama.benchmark \
 
 5. ONNX Runtime, FP32, Microsoft custom export
 ```
-python3 -m models.llama.benchmark \
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type ort-msft \
     --ort-model-path ./llama-2-onnx/7B_float32/ONNX/LlamaV2_7B_float32.onnx \
     --model-name meta-llama/Llama-2-7b-hf \
@@ -316,7 +299,7 @@ python3 -m models.llama.benchmark \
 
 6. ONNX Runtime, FP16, Microsoft custom export
 ```
-python3 -m models.llama.benchmark \
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type ort-msft \
     --ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \
     --model-name meta-llama/Llama-2-7b-hf \
@@ -367,7 +350,7 @@ You can profile a variant by adding the `--profile` flag and providing one batch
 ### Benchmark All
 You can use `benchmark_all.py` to benchmark across various options and automatically store the results in a CSV file. Here is an example.
 ```
-python3 -m models.llama.benchmark_all \
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark_all \
     --hf-pt-eager \
     --hf-pt-compile \
     --hf-ort-dir-path ./llama2-7b-fp16/ \
diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
index bc09b52574a27..71f52faa2c1e6 100644
--- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
@@ -4,6 +4,8 @@
 import logging
 import os
 import shutil
+import subprocess
+import sys
 from itertools import chain
 
 import onnx
@@ -408,6 +410,31 @@ def optimize_export(config: AutoConfig, input_path: str, output_path: str, remov
         only_onnxruntime=False,
     )
     model_opt.save_model_to_file(output_path, use_external_data_format=True)
+
+    # Run symbolic shape inference on optimized model to avoid shape errors during runtime
+    # Ex: Before attention fusion, RotaryEmbedding assumes a 4D input and produces a 4D output.
+    # After attention fusion, RotaryEmbedding expects a 3D input and produces a 3D output.
+    wheel_cmd = [sys.executable, "-m", "onnxruntime.tools.symbolic_shape_infer"]
+    source_cmd = [sys.executable, "../symbolic_shape_infer.py"]
+    symbolic_shape_infer_args = [
+        "--input",
+        output_path,
+        "--output",
+        output_path,
+        "--auto_merge",
+        "--save_as_external_data",
+        "--all_tensors_to_one_file",
+        "--external_data_location",
+        os.path.basename(output_path) + ".data",
+    ]
+
+    file_path = os.path.dirname(__file__)
+    if os.path.exists(os.path.join(file_path, "../../../tools/symbolic_shape_infer.py")):
+        main_cmd = wheel_cmd
+    else:
+        main_cmd = source_cmd
+    subprocess.run(main_cmd + symbolic_shape_infer_args)  # noqa: PLW1510
+
     logger.info(f"The ONNX model at {input_path} has been successfully optimized and saved at {output_path}!")
     if remove_model:
         remove_existing_model(input_path)
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py
index 94e0397116d1c..89b459c80beec 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py
@@ -21,6 +21,7 @@ def setup_torch_model(args, location, use_auth_token, torch_dtype=torch.float32,
         if i == rank % (world_size):
             l_config = AutoConfig.from_pretrained(location, use_auth_token=use_auth_token, cache_dir=args.cache_dir)
             l_config.use_cache = True
+            l_config._attn_implementation = "eager"  # "eager" uses LlamaAttention for attention layer
             llama = AutoModelForCausalLM.from_pretrained(
                 location,
                 use_auth_token=use_auth_token,
diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt
index 4210f36982aef..b72c972e7a16a 100644
--- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt
+++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt
@@ -1,4 +1,4 @@
-git+https://github.com/huggingface/optimum.git
+optimum>=1.14.1
 transformers>=4.33.2
 torch>=2.2.0.dev20230920
 onnx>=1.14.0
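Note: the symbolic shape inference step that `optimize_export` now runs can also be invoked on its own if an already-optimized model hits shape errors at runtime. Below is a minimal sketch that mirrors the arguments used in `convert_to_onnx.py`; the model path is illustrative:

```python
import os
import subprocess
import sys

# Illustrative path to an optimized LLaMA-2 ONNX model.
output_path = "./llama2-7b-fp16/model_opt.onnx"

# Same arguments that convert_to_onnx.py passes to onnxruntime.tools.symbolic_shape_infer.
subprocess.run(
    [
        sys.executable, "-m", "onnxruntime.tools.symbolic_shape_infer",
        "--input", output_path,
        "--output", output_path,
        "--auto_merge",
        "--save_as_external_data",
        "--all_tensors_to_one_file",
        "--external_data_location", os.path.basename(output_path) + ".data",
    ],
    check=True,
)
```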