Update LLaMA attention fusions #19200

Merged
10 changes: 10 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_rotary_attention.py
@@ -539,6 +539,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):

# attn_mask_nodes_1, attn_mask_nodes_2 are for LLaMA-2 Microsoft's 3D attention mask
# attn_mask_nodes_3, attn_mask_nodes_4 are for LLaMA-2 Hugging Face's 2D attention mask
# attn_mask_nodes_5, attn_mask_nodes_6 are for LLaMA-2 Microsoft's model for the DML EP
# attn_mask_nodes_7 is for LLaMA-2 Hugging Face's changes to the attention mask
attn_mask, add_qk_str = "", ""
attn_mask_nodes_1 = self.model.match_parent_path(
add_qk,
@@ -570,6 +572,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
["Expand", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
[1, 0, 2, 1, 0, 0, 0],
)
attn_mask_nodes_7 = self.model.match_parent_path(
add_qk,
["Where", "Cast", "Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
[1, 0, 0, 0, 0, 1, 0, 0, 0],
)
if attn_mask_nodes_1 is not None:
_, slice_mask_1, slice_mask_2 = attn_mask_nodes_1
attn_mask = slice_mask_1.output[0]
@@ -588,6 +595,9 @@ elif attn_mask_nodes_6 is not None:
elif attn_mask_nodes_6 is not None:
# The mask has already been reshaped to (B,N,S,T)
add_qk_str = attn_mask_nodes_6[0].output[0]
elif attn_mask_nodes_7 is not None:
# Reshape from (B,1,S,T) to (B,N,S,T)
add_qk_str = self.reshape_add_qk(attn_mask_nodes_7[0].output[0])
else:
logger.debug("fuse_rotary_attention: failed to match attention mask nodes")
return
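
For context on what the new `attn_mask_nodes_7` pattern checks: `match_parent_path` walks upward from a node, one producer per step, verifying the expected op type at each listed input index and returning the matched nodes in order (or `None` on the first mismatch). A rough sketch of that behavior, assuming `output_name_to_node` maps tensor names to their producing nodes (an illustration, not ONNX Runtime's actual implementation):

```
# Sketch only: approximates what OnnxModel.match_parent_path does.
def match_parent_path_sketch(node, parent_op_types, parent_input_indices, output_name_to_node):
    matched = []
    current = node
    for op_type, input_index in zip(parent_op_types, parent_input_indices):
        if input_index >= len(current.input):
            return None
        parent = output_name_to_node.get(current.input[input_index])
        if parent is None or parent.op_type != op_type:
            return None  # reached a graph input/initializer or hit an unexpected op
        matched.append(parent)
        current = parent
    return matched
```

So `attn_mask_nodes_7` only matches when input 1 of `add_qk` is produced by a `Where`, whose input 0 comes from a `Cast`, and so on up through the final two `Unsqueeze` nodes; the output of the first matched node (the outer `Where`) is then reshaped from (B,1,S,T) to (B,N,S,T) via `reshape_add_qk`.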
31 changes: 7 additions & 24 deletions onnxruntime/python/tools/transformers/models/llama/README.md
@@ -42,23 +42,6 @@ $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama

To make this option compatible with [Hugging Face's Optimum](https://github.com/huggingface/optimum), you will need to create `config.json` and `generation_config.json` for your model and store them in the same directory as your ONNX models. For example, you can find those JSON files for LLaMA-2 7B on Hugging Face [here](https://huggingface.co/meta-llama/Llama-2-7b-hf).
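
For reference, here is a minimal sketch, not part of this PR, of loading the exported model through Optimum once those JSON files are in place; the directory path and prompt below are placeholders:

```
# Minimal sketch (placeholder path and prompt): run the exported ONNX model via Optimum.
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

# Directory containing the ONNX file(s) plus config.json and generation_config.json
model = ORTModelForCausalLM.from_pretrained("./Llama-2-7b-hf-onnx/")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```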

As indicated in `requirements.txt`, you will also need to install Optimum from source. Once installed, you will need to modify `ORTModelForCausalLM.forward` in `optimum/optimum/onnxruntime/modeling_decoder.py` as follows:

```
# Before
if self.use_cache:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
# Flatten the past_key_values (no need to flatten for models using multi-query attn)


# After
if self.use_cache:
if past_key_values is not None:
input_ids = input_ids[:, -1:] if past_key_values[0][0].shape[2] != 0 else input_ids
# Flatten the past_key_values (no need to flatten for models using multi-query attn)
```

### Option 2: from [Microsoft's custom export](https://github.com/microsoft/Llama-2-Onnx)

Please follow the [README instructions](https://github.com/microsoft/Llama-2-Onnx#before-you-start) in the custom export of LLaMA-2.
@@ -254,7 +237,7 @@ Here are some examples of how you can benchmark LLaMA-2.

1. PyTorch without `torch.compile`, FP32
```
python3 -m models.llama.benchmark \
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type hf-pt-eager \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp32 \
@@ -266,7 +249,7 @@ python3 -m models.llama.benchmark \

2. PyTorch with `torch.compile`, FP16
```
python3 -m models.llama.benchmark \
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type hf-pt-compile \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp16 \
@@ -278,7 +261,7 @@ python3 -m models.llama.benchmark \

3. Optimum + ONNX Runtime, FP32, export via Optimum or convert_to_onnx
```
python3 -m models.llama.benchmark \
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type hf-ort \
--hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \
--model-name meta-llama/Llama-2-7b-hf \
@@ -291,7 +274,7 @@ python3 -m models.llama.benchmark \

4. Optimum + ONNX Runtime, FP16, export via Optimum or convert_to_onnx
```
python3 -m models.llama.benchmark \
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type hf-ort \
--hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \
--model-name meta-llama/Llama-2-7b-hf \
@@ -304,7 +287,7 @@ python3 -m models.llama.benchmark \

5. ONNX Runtime, FP32, Microsoft custom export
```
python3 -m models.llama.benchmark \
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type ort-msft \
--ort-model-path ./llama-2-onnx/7B_float32/ONNX/LlamaV2_7B_float32.onnx \
--model-name meta-llama/Llama-2-7b-hf \
@@ -316,7 +299,7 @@ python3 -m models.llama.benchmark \

6. ONNX Runtime, FP16, Microsoft custom export
```
python3 -m models.llama.benchmark \
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type ort-msft \
--ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \
--model-name meta-llama/Llama-2-7b-hf \
@@ -367,7 +350,7 @@ You can profile a variant by adding the `--profile` flag and providing one batch
### Benchmark All
You can use `benchmark_all.py` to benchmark across various options and automatically store the results in a CSV file. Here is an example.
```
python3 -m models.llama.benchmark_all \
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark_all \
--hf-pt-eager \
--hf-pt-compile \
--hf-ort-dir-path ./llama2-7b-fp16/ \
@@ -4,6 +4,8 @@
import logging
import os
import shutil
import subprocess
import sys
from itertools import chain

import onnx
@@ -408,6 +410,31 @@ def optimize_export(config: AutoConfig, input_path: str, output_path: str, remov
only_onnxruntime=False,
)
model_opt.save_model_to_file(output_path, use_external_data_format=True)

# Run symbolic shape inference on optimized model to avoid shape errors during runtime
# Ex: Before attention fusion, RotaryEmbedding assumes a 4D input and produces a 4D output.
# After attention fusion, RotaryEmbedding expects a 3D input and produces a 3D output.
wheel_cmd = [sys.executable, "-m", "onnxruntime.tools.symbolic_shape_infer"]
source_cmd = [sys.executable, "../symbolic_shape_infer.py"]
symbolic_shape_infer_args = [
"--input",
output_path,
"--output",
output_path,
"--auto_merge",
"--save_as_external_data",
"--all_tensors_to_one_file",
"--external_data_location",
os.path.basename(output_path) + ".data",
]

file_path = os.path.dirname(__file__)
if os.path.exists(os.path.join(file_path, "../../../tools/symbolic_shape_infer.py")):
main_cmd = wheel_cmd
else:
main_cmd = source_cmd
subprocess.run(main_cmd + symbolic_shape_infer_args) # noqa: PLW1510

logger.info(f"The ONNX model at {input_path} has been successfully optimized and saved at {output_path}!")
if remove_model:
remove_existing_model(input_path)
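
For reference, the subprocess call above is equivalent to invoking the shape-inference tool directly; a sketch of the command it builds when the installed wheel is used (the model path below is a placeholder):

```
python3 -m onnxruntime.tools.symbolic_shape_infer \
    --input path/to/model.onnx \
    --output path/to/model.onnx \
    --auto_merge \
    --save_as_external_data \
    --all_tensors_to_one_file \
    --external_data_location model.onnx.data
```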
@@ -21,6 +21,7 @@ def setup_torch_model(args, location, use_auth_token, torch_dtype=torch.float32,
if i == rank % (world_size):
l_config = AutoConfig.from_pretrained(location, use_auth_token=use_auth_token, cache_dir=args.cache_dir)
l_config.use_cache = True
l_config._attn_implementation = "eager" # "eager" uses LlamaAttention for attention layer
llama = AutoModelForCausalLM.from_pretrained(
location,
use_auth_token=use_auth_token,
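
A minimal sketch of what this change amounts to when loading the model yourself (the model ID is the one used elsewhere in this PR; other `from_pretrained` arguments are omitted):

```
# Minimal sketch: force the eager LlamaAttention implementation so the exported
# graph contains the attention subgraph that the fusions above expect.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
config.use_cache = True
config._attn_implementation = "eager"  # avoid SDPA/flash-attention variants during export

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", config=config)
```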
@@ -1,4 +1,4 @@
git+https://github.com/huggingface/optimum.git
optimum>=1.14.1
transformers>=4.33.2
torch>=2.2.0.dev20230920
onnx>=1.14.0