Update LLaMA attention fusions (#19200)

### Description
This PR updates the LLaMA-2 attention fusions by adding the following.

- Loading the PyTorch model from Hugging Face with the `LlamaAttention`
class before exporting
- Updating the attention mask pattern matching to support another case

This PR also fixes [this
issue](#19040).

### Motivation and Context
Recent changes to Hugging Face's `transformers` library break the
existing pattern matching. Since the attention fusions aim to change the
graph from `LayerNorm Op --> Set of Attention Nodes --> LayerNorm Op` to
`LayerNorm Op --> Attention Op --> LayerNorm Op` per layer, it ultimately
does not matter which nodes comprise the `Set of Attention Nodes`: they
are all removed and replaced by the `Attention Op` in the end.

Therefore, it also does not matter whether the `LlamaAttention` class or a
different attention class is used to load the PyTorch model before
exporting, because the expected graphs after the attention fusions look
identical regardless of the attention class chosen. By loading the
PyTorch model with the `LlamaAttention` class instead of another attention
class (e.g. `LlamaFlashAttention2` or `LlamaSdpaAttention`) and then
exporting it to ONNX, the existing pattern matching continues to work.
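
A minimal sketch of that export flow, for illustration only (the model name, dummy input shapes, opset, and output path are placeholders; the real export lives in the `convert_to_onnx` script and also wires up the KV cache inputs and outputs):

```
import torch
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"
config = AutoConfig.from_pretrained(model_id)
config.use_cache = False      # disabled only to keep this sketch small; the real export keeps the KV cache
config.return_dict = False    # return a plain tuple so tracing sees tensor outputs
config._attn_implementation = "eager"  # "eager" selects LlamaAttention, which the fusion patterns expect

model = AutoModelForCausalLM.from_pretrained(model_id, config=config, torch_dtype=torch.float32)
model.eval()

# Placeholder dummy inputs used only for tracing.
input_ids = torch.randint(0, config.vocab_size, (1, 8), dtype=torch.int64)
attention_mask = torch.ones((1, 8), dtype=torch.int64)

torch.onnx.export(
    model,
    (input_ids, attention_mask),
    "llama2_decoder.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "sequence_length"},
        "logits": {0: "batch_size", 1: "sequence_length"},
    },
    opset_version=17,
)
```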
kunal-vaishnavi authored and rachguo committed Jan 23, 2024
1 parent 2b86515 commit 5eebd09
Showing 5 changed files with 46 additions and 25 deletions.
10 changes: 10 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_rotary_attention.py
@@ -539,6 +539,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):

# attn_mask_nodes_1, attn_mask_nodes_2 are for LLaMA-2 Microsoft's 3D attention mask
# attn_mask_nodes_3, attn_mask_nodes_4 are for LLaMA-2 Hugging Face's 2D attention mask
# attn_mask_nodes_5, attn_mask_nodes_6 are for LLaMA-2 Microsoft's model for the DML EP
# attn_mask_nodes_7 is for LLaMA-2 Hugging Face's changes to the attention mask
attn_mask, add_qk_str = "", ""
attn_mask_nodes_1 = self.model.match_parent_path(
add_qk,
@@ -570,6 +572,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
["Expand", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
[1, 0, 2, 1, 0, 0, 0],
)
attn_mask_nodes_7 = self.model.match_parent_path(
add_qk,
["Where", "Cast", "Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
[1, 0, 0, 0, 0, 1, 0, 0, 0],
)
if attn_mask_nodes_1 is not None:
_, slice_mask_1, slice_mask_2 = attn_mask_nodes_1
attn_mask = slice_mask_1.output[0]
@@ -588,6 +595,9 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
elif attn_mask_nodes_6 is not None:
# The mask has already been reshaped to (B,N,S,T)
add_qk_str = attn_mask_nodes_6[0].output[0]
elif attn_mask_nodes_7 is not None:
# Reshape from (B,1,S,T) to (B,N,S,T)
add_qk_str = self.reshape_add_qk(attn_mask_nodes_7[0].output[0])
else:
logger.debug("fuse_rotary_attention: failed to match attention mask nodes")
return
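
For context, the `self.model.match_parent_path` calls above walk backwards from a node (here `add_qk`, the Add that applies the attention mask to the QK scores), checking that the parent reached through each listed input index has the expected op type. Below is a simplified, standalone sketch of that idea, not the actual fusion helper; it assumes a single producer per output and ignores graph inputs and initializers:

```
import onnx

def match_parent_path(graph, start_node, op_types, input_indices):
    """Walk backwards from start_node; return the matched parent nodes or None."""
    # Map each tensor name to the node that produces it.
    producers = {output: node for node in graph.node for output in node.output}
    current, matched = start_node, []
    for op_type, input_index in zip(op_types, input_indices):
        if input_index >= len(current.input):
            return None
        parent = producers.get(current.input[input_index])
        if parent is None or parent.op_type != op_type:
            return None
        matched.append(parent)
        current = parent
    return matched

# Example: trace the new Hugging Face 2D attention mask pattern from the add_qk node.
# model = onnx.load("llama2_decoder.onnx")  # placeholder path
# add_qk = ...  # the Add node that adds the attention mask to the QK scores
# nodes = match_parent_path(
#     model.graph,
#     add_qk,
#     ["Where", "Cast", "Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
#     [1, 0, 0, 0, 0, 1, 0, 0, 0],
# )
```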
31 changes: 7 additions & 24 deletions onnxruntime/python/tools/transformers/models/llama/README.md
@@ -42,23 +42,6 @@ $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama

To make this option compatible with [Hugging Face's Optimum](https://github.com/huggingface/optimum), you will need to create `config.json` and `generation_config.json` for your model and store them in the same directory as your ONNX models. For example, you can find those JSON files for LLaMA-2 7B on Hugging Face [here](https://huggingface.co/meta-llama/Llama-2-7b-hf).

As indicated in `requirements.txt`, you will also need to install Optimum from source. Once installed, you will need to modify `ORTModelForCausalLM.forward` in `optimum/optimum/onnxruntime/modeling_decoder.py` as follows:

```
# Before
if self.use_cache:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
# Flatten the past_key_values (no need to flatten for models using multi-query attn)
# After
if self.use_cache:
if past_key_values is not None:
input_ids = input_ids[:, -1:] if past_key_values[0][0].shape[2] != 0 else input_ids
# Flatten the past_key_values (no need to flatten for models using multi-query attn)
```

### Option 2: from [Microsoft's custom export](https://github.com/microsoft/Llama-2-Onnx)

Please follow the [README instructions](https://github.com/microsoft/Llama-2-Onnx#before-you-start) in the custom export of LLaMA-2.
@@ -254,7 +237,7 @@ Here are some examples of how you can benchmark LLaMA-2.

1. PyTorch without `torch.compile`, FP32
```
python3 -m models.llama.benchmark \
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type hf-pt-eager \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp32 \
@@ -266,7 +249,7 @@

2. PyTorch with `torch.compile`, FP16
```
python3 -m models.llama.benchmark \
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type hf-pt-compile \
--model-name meta-llama/Llama-2-7b-hf \
--precision fp16 \
@@ -278,7 +261,7 @@

3. Optimum + ONNX Runtime, FP32, export via Optimum or convert_to_onnx
```
python3 -m models.llama.benchmark \
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type hf-ort \
--hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \
--model-name meta-llama/Llama-2-7b-hf \
@@ -291,7 +274,7 @@

4. Optimum + ONNX Runtime, FP16, export via Optimum or convert_to_onnx
```
python3 -m models.llama.benchmark \
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type hf-ort \
--hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \
--model-name meta-llama/Llama-2-7b-hf \
@@ -304,7 +287,7 @@

5. ONNX Runtime, FP32, Microsoft custom export
```
python3 -m models.llama.benchmark \
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type ort-msft \
--ort-model-path ./llama-2-onnx/7B_float32/ONNX/LlamaV2_7B_float32.onnx \
--model-name meta-llama/Llama-2-7b-hf \
@@ -316,7 +299,7 @@

6. ONNX Runtime, FP16, Microsoft custom export
```
python3 -m models.llama.benchmark \
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark \
--benchmark-type ort-msft \
--ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \
--model-name meta-llama/Llama-2-7b-hf \
@@ -367,7 +350,7 @@ You can profile a variant by adding the `--profile` flag and providing one batch
### Benchmark All
You can use `benchmark_all.py` to benchmark across various options and automatically store the results in a CSV file. Here is an example.
```
python3 -m models.llama.benchmark_all \
CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.benchmark_all \
--hf-pt-eager \
--hf-pt-compile \
--hf-ort-dir-path ./llama2-7b-fp16/ \
@@ -4,6 +4,8 @@
import logging
import os
import shutil
import subprocess
import sys
from itertools import chain

import onnx
@@ -408,6 +410,31 @@ def optimize_export(config: AutoConfig, input_path: str, output_path: str, remov
only_onnxruntime=False,
)
model_opt.save_model_to_file(output_path, use_external_data_format=True)

# Run symbolic shape inference on optimized model to avoid shape errors during runtime
# Ex: Before attention fusion, RotaryEmbedding assumes a 4D input and produces a 4D output.
# After attention fusion, RotaryEmbedding expects a 3D input and produces a 3D output.
wheel_cmd = [sys.executable, "-m", "onnxruntime.tools.symbolic_shape_infer"]
source_cmd = [sys.executable, "../symbolic_shape_infer.py"]
symbolic_shape_infer_args = [
"--input",
output_path,
"--output",
output_path,
"--auto_merge",
"--save_as_external_data",
"--all_tensors_to_one_file",
"--external_data_location",
os.path.basename(output_path) + ".data",
]

file_path = os.path.dirname(__file__)
if os.path.exists(os.path.join(file_path, "../../../tools/symbolic_shape_infer.py")):
main_cmd = wheel_cmd
else:
main_cmd = source_cmd
subprocess.run(main_cmd + symbolic_shape_infer_args) # noqa: PLW1510

logger.info(f"The ONNX model at {input_path} has been successfully optimized and saved at {output_path}!")
if remove_model:
remove_existing_model(input_path)
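
The subprocess call added above shells out to ONNX Runtime's symbolic shape inference script. For reference, a minimal in-process sketch of the same step, assuming the installed wheel exposes `onnxruntime.tools.symbolic_shape_infer` (the file names below are placeholders):

```
import onnx
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

# Run symbolic shape inference on the optimized model so shapes stay consistent
# after fusions change tensor ranks (e.g. RotaryEmbedding's 4D -> 3D inputs).
model = onnx.load("llama2_opt.onnx")
inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)

# Save the result with external data, mirroring the flags passed to the script above.
onnx.save(
    inferred,
    "llama2_opt.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="llama2_opt.onnx.data",
)
```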
@@ -21,6 +21,7 @@ def setup_torch_model(args, location, use_auth_token, torch_dtype=torch.float32,
if i == rank % (world_size):
l_config = AutoConfig.from_pretrained(location, use_auth_token=use_auth_token, cache_dir=args.cache_dir)
l_config.use_cache = True
l_config._attn_implementation = "eager" # "eager" uses LlamaAttention for attention layer
llama = AutoModelForCausalLM.from_pretrained(
location,
use_auth_token=use_auth_token,
@@ -1,4 +1,4 @@
git+https://github.com/huggingface/optimum.git
optimum>=1.14.1
transformers>=4.33.2
torch>=2.2.0.dev20230920
onnx>=1.14.0
