-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update LLaMA attention fusions (#19200)
### Description This PR updates the LLaMA-2 attention fusions by adding the following. - Loading the PyTorch model from Hugging Face with the `LlamaAttention` class before exporting - Updating the attention mask pattern matching to support another case This PR also fixes [this issue](#19040). ### Motivation and Context Recent changes to Hugging Face's `transformers` library break the existing pattern matching. Since the attention fusions aim to change the graph from `LayerNorm Op --> Set of Attention Nodes --> LayerNorm Op` to `LayerNorm Op --> Attention Op --> LayerNorm Op` per layer, ultimately it does not matter what nodes comprise the `Set of Attention Nodes` because they will all be removed and replaced by the `Attention Op` in the end. Therefore, it does not matter whether the `LlamaAttention` class or a different attention class is used to load the PyTorch model before exporting because the expected graphs after the attention fusions will look identical no matter the attention class chosen. By loading the PyTorch model with the `LlamaAttention` class instead of other attention classes (e.g. `LlamaFlashAttention2` or `LlamaSdpaAttention`) and then exporting it to ONNX, the existing pattern matching will continue to work.
- Loading branch information
1 parent
2b86515
commit 5eebd09
Showing
5 changed files
with
46 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
onnxruntime/python/tools/transformers/models/llama/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters