diff --git a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py
index de89b35366a23..618d3c2fab12c 100644
--- a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py
+++ b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py
@@ -539,6 +539,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
 
         # attn_mask_nodes_1, attn_mask_nodes_2 are for LLaMA-2 Microsoft's 3D attention mask
         # attn_mask_nodes_3, attn_mask_nodes_4 are for LLaMA-2 Hugging Face's 2D attention mask
+        # attn_mask_nodes_5, attn_mask_nodes_6 are for LLaMA-2 Microsoft's model for the DML EP
+        # attn_mask_nodes_7 is for LLaMA-2 Hugging Face's changes to the attention mask
         attn_mask, add_qk_str = "", ""
         attn_mask_nodes_1 = self.model.match_parent_path(
             add_qk,
@@ -570,6 +572,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
             ["Expand", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
             [1, 0, 2, 1, 0, 0, 0],
         )
+        attn_mask_nodes_7 = self.model.match_parent_path(
+            add_qk,
+            ["Where", "Cast", "Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
+            [1, 0, 0, 0, 0, 1, 0, 0, 0],
+        )
         if attn_mask_nodes_1 is not None:
             _, slice_mask_1, slice_mask_2 = attn_mask_nodes_1
             attn_mask = slice_mask_1.output[0]
@@ -588,6 +595,9 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
         elif attn_mask_nodes_6 is not None:
             # The mask has already been reshaped to (B,N,S,T)
             add_qk_str = attn_mask_nodes_6[0].output[0]
+        elif attn_mask_nodes_7 is not None:
+            # Reshape from (B,1,S,T) to (B,N,S,T)
+            add_qk_str = self.reshape_add_qk(attn_mask_nodes_7[0].output[0])
         else:
             logger.debug("fuse_rotary_attention: failed to match attention mask nodes")
             return
diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md
index e7bcc19635f40..f9552e02d74b9 100644
--- a/onnxruntime/python/tools/transformers/models/llama/README.md
+++ b/onnxruntime/python/tools/transformers/models/llama/README.md
@@ -42,23 +42,6 @@ $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama
 
 To make this option compatible with [Hugging Face's Optimum](https://github.com/huggingface/optimum), you will need to create `config.json` and `generation_config.json` for your model and store them in the same directory as your ONNX models. For example, you can find those JSON files for LLaMA-2 7B on Hugging Face [here](https://huggingface.co/meta-llama/Llama-2-7b-hf).
 
-As indicated in `requirements.txt`, you will also need to install Optimum from source. Once installed, you will need to modify `ORTModelForCausalLM.forward` in `optimum/optimum/onnxruntime/modeling_decoder.py` as follows:
-
-```
-# Before
-if self.use_cache:
-    if past_key_values is not None:
-        input_ids = input_ids[:, -1:]
-    # Flatten the past_key_values (no need to flatten for models using multi-query attn)
-
-
-# After
-if self.use_cache:
-    if past_key_values is not None:
-        input_ids = input_ids[:, -1:] if past_key_values[0][0].shape[2] != 0 else input_ids
-    # Flatten the past_key_values (no need to flatten for models using multi-query attn)
-```
-
 ### Option 2: from [Microsoft's custom export](https://github.com/microsoft/Llama-2-Onnx)
 
 Please follow the [README instructions](https://github.com/microsoft/Llama-2-Onnx#before-you-start) in the custom export of LLaMA-2.
@@ -254,7 +237,7 @@ Here are some examples of how you can benchmark LLaMA-2.
 
 1. PyTorch without `torch.compile`, FP32
 ```
-python3 -m models.llama.benchmark \
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type hf-pt-eager \
     --model-name meta-llama/Llama-2-7b-hf \
     --precision fp32 \
@@ -266,7 +249,7 @@ python3 -m models.llama.benchmark \
 
 2. PyTorch with `torch.compile`, FP16
 ```
-python3 -m models.llama.benchmark \
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type hf-pt-compile \
     --model-name meta-llama/Llama-2-7b-hf \
     --precision fp16 \
@@ -278,7 +261,7 @@ python3 -m models.llama.benchmark \
 
 3. Optimum + ONNX Runtime, FP32, export via Optimum or convert_to_onnx
 ```
-python3 -m models.llama.benchmark \
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type hf-ort \
     --hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \
     --model-name meta-llama/Llama-2-7b-hf \
@@ -291,7 +274,7 @@ python3 -m models.llama.benchmark \
 
 4. Optimum + ONNX Runtime, FP16, export via Optimum or convert_to_onnx
 ```
-python3 -m models.llama.benchmark \
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type hf-ort \
     --hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \
     --model-name meta-llama/Llama-2-7b-hf \
@@ -304,7 +287,7 @@ python3 -m models.llama.benchmark \
 
 5. ONNX Runtime, FP32, Microsoft custom export
 ```
-python3 -m models.llama.benchmark \
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type ort-msft \
     --ort-model-path ./llama-2-onnx/7B_float32/ONNX/LlamaV2_7B_float32.onnx \
     --model-name meta-llama/Llama-2-7b-hf \
@@ -316,7 +299,7 @@ python3 -m models.llama.benchmark \
 
 6. ONNX Runtime, FP16, Microsoft custom export
 ```
-python3 -m models.llama.benchmark \
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
     --benchmark-type ort-msft \
     --ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \
     --model-name meta-llama/Llama-2-7b-hf \
@@ -367,7 +350,7 @@ You can profile a variant by adding the `--profile` flag and providing one batch
 ### Benchmark All
 You can use `benchmark_all.py` to benchmark across various options and automatically store the results in a CSV file. Here is an example.
 ```
-python3 -m models.llama.benchmark_all \
+CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark_all \
     --hf-pt-eager \
     --hf-pt-compile \
     --hf-ort-dir-path ./llama2-7b-fp16/ \
diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
index bc09b52574a27..71f52faa2c1e6 100644
--- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
@@ -4,6 +4,8 @@
 import logging
 import os
 import shutil
+import subprocess
+import sys
 from itertools import chain
 
 import onnx
@@ -408,6 +410,31 @@ def optimize_export(config: AutoConfig, input_path: str, output_path: str, remov
         only_onnxruntime=False,
     )
     model_opt.save_model_to_file(output_path, use_external_data_format=True)
+
+    # Run symbolic shape inference on optimized model to avoid shape errors during runtime
+    # Ex: Before attention fusion, RotaryEmbedding assumes a 4D input and produces a 4D output.
+    # After attention fusion, RotaryEmbedding expects a 3D input and produces a 3D output.
+    wheel_cmd = [sys.executable, "-m", "onnxruntime.tools.symbolic_shape_infer"]
+    source_cmd = [sys.executable, "../symbolic_shape_infer.py"]
+    symbolic_shape_infer_args = [
+        "--input",
+        output_path,
+        "--output",
+        output_path,
+        "--auto_merge",
+        "--save_as_external_data",
+        "--all_tensors_to_one_file",
+        "--external_data_location",
+        os.path.basename(output_path) + ".data",
+    ]
+
+    file_path = os.path.dirname(__file__)
+    if os.path.exists(os.path.join(file_path, "../../../tools/symbolic_shape_infer.py")):
+        main_cmd = wheel_cmd
+    else:
+        main_cmd = source_cmd
+    subprocess.run(main_cmd + symbolic_shape_infer_args)  # noqa: PLW1510
+
     logger.info(f"The ONNX model at {input_path} has been successfully optimized and saved at {output_path}!")
     if remove_model:
         remove_existing_model(input_path)
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py
index 94e0397116d1c..89b459c80beec 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py
@@ -21,6 +21,7 @@ def setup_torch_model(args, location, use_auth_token, torch_dtype=torch.float32,
         if i == rank % (world_size):
             l_config = AutoConfig.from_pretrained(location, use_auth_token=use_auth_token, cache_dir=args.cache_dir)
             l_config.use_cache = True
+            l_config._attn_implementation = "eager"  # "eager" uses LlamaAttention for attention layer
             llama = AutoModelForCausalLM.from_pretrained(
                 location,
                 use_auth_token=use_auth_token,
diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt
index 4210f36982aef..b72c972e7a16a 100644
--- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt
+++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt
@@ -1,4 +1,4 @@
-git+https://github.com/huggingface/optimum.git
+optimum>=1.14.1
 transformers>=4.33.2
 torch>=2.2.0.dev20230920
 onnx>=1.14.0
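
A few notes on the changes above.

The new `attn_mask_nodes_7` pattern in `fusion_rotary_attention.py` follows the usual `match_parent_path` convention: starting from `add_qk`, the matcher repeatedly follows the i-th input of the current node to its producer and checks the producer's op type. Below is a conceptual sketch of that convention only, assuming a prebuilt output-name-to-node map; it is not ONNX Runtime's actual implementation.

```
# Conceptual sketch (not ONNX Runtime's real code) of the parent-path matching
# convention used in fusion_rotary_attention.py: from `node`, follow input[idx]
# to its producer and require the expected op type at each step.
def match_parent_path(node, expected_op_types, input_indices, output_name_to_node):
    path = []
    current = node
    for op_type, idx in zip(expected_op_types, input_indices):
        producer = output_name_to_node.get(current.input[idx])
        if producer is None or producer.op_type != op_type:
            return None  # no match; the caller falls through to the next pattern
        path.append(producer)
        current = producer
    # e.g. [Where, Cast, Where, Cast, Sub, Cast, Expand, Unsqueeze, Unsqueeze]
    return path
```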
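The block added to `optimize_export` shells out to the symbolic shape inference tool (through the installed wheel when available, otherwise the source tree). If you want to run the same step in-process, a sketch along these lines should work with an installed `onnxruntime` wheel; the file names are placeholders.

```
# In-process sketch of the shape inference step that convert_to_onnx.py runs
# via subprocess. Paths are placeholders; adjust to your exported model.
import onnx
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

model = onnx.load("llama2-7b-fp16/model.onnx")  # adjacent .data file is loaded automatically
inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)  # mirrors --auto_merge
onnx.save(
    inferred,
    "llama2-7b-fp16/model.onnx",
    save_as_external_data=True,  # LLaMA-2 weights exceed the 2 GB protobuf limit
    all_tensors_to_one_file=True,
    location="model.onnx.data",
)
```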
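Finally, `llama_torch.py` now pins `_attn_implementation` to `"eager"` so that newer `transformers` releases do not silently select the SDPA attention path, whose traced graph would not match the fusion patterns above. A minimal standalone equivalent (model name and dtype are examples):

```
# Minimal sketch of loading LLaMA-2 with the eager LlamaAttention path,
# matching what setup_torch_model now does. Model name and dtype are examples.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
config.use_cache = True
config._attn_implementation = "eager"  # "eager" uses LlamaAttention

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    config=config,
    torch_dtype=torch.float16,
)
```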