Skip to content

Commit

Permalink
[ORTModule] Adjust Attention Patterns for Efficient Attention ATen Fa…
Browse files Browse the repository at this point in the history
…llback (#18471)

Adjust attention patterns to match latest Whisper+exporter. Also add
some condition check and add docs.
  • Loading branch information
centwang authored Nov 22, 2023
1 parent 7c57305 commit 3bc9efc
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 90 deletions.
18 changes: 18 additions & 0 deletions docs/ORTModule_Training_Guidelines.md
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,15 @@ data sparsity based performance optimizations.
unset ORTMODULE_CACHE_DIR # Disable
```

#### ORTMODULE_USE_EFFICIENT_ATTENTION

- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, this is disabled. This env var can be used for enabling attention fusion and falling back to PyTorch's efficient_attention ATen kernel for execution. NOTE that it requires torch version 2.1.1 or above. There are some built-in patterns for attention fusion; if none of the patterns work for your model, you can add a custom one in your user script manually.

```bash
export ORTMODULE_USE_EFFICIENT_ATTENTION=1
```

### 2.2 Memory Optimization

Q: *Want to run a bigger batch size?*
Expand Down Expand Up @@ -397,6 +406,15 @@ Check [FP16_Optimizer implementation](../orttraining/orttraining/python/training
export ORTMODULE_TUNING_RESULTS_PATH=/tmp/tuning_results
```

#### ORTMODULE_USE_FLASH_ATTENTION

- **Feature Area**: *ORTMODULE/TritonOp*
- **Description**: By default, this is disabled. This env var can be used for enabling attention fusion and using Flash Attention's Triton version as the kernel. NOTE that it requires ORTMODULE_USE_TRITON to be enabled and a CUDA device with compute capability 8.0 or above. There are some built-in patterns for attention fusion; if none of the patterns work for your model, you can add a custom one in your user script manually.

```bash
export ORTMODULE_USE_FLASH_ATTENTION=1
```

#### ORTMODULE_TRITON_DEBUG

- **Feature Area**: *ORTMODULE/TritonOp*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import os

import torch

from ._mm import triton_gemm, triton_gemm_out, triton_matmul, triton_matmul_out # noqa: F401
from ._slice_scel import slice_scel, slice_scel_backward # noqa: F401

Expand All @@ -17,7 +19,12 @@
"slice_scel_backward",
]

if "ORTMODULE_USE_FLASH_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_FLASH_ATTENTION")) == 1:
if (
"ORTMODULE_USE_FLASH_ATTENTION" in os.environ
and int(os.getenv("ORTMODULE_USE_FLASH_ATTENTION")) == 1
and torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 8
):
from ._flash_attn import flash_attn_backward, flash_attn_forward # noqa: F401

_all_kernels.extend(["flash_attn_forward", "flash_attn_backward"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,16 @@

import os

import torch
from packaging.version import Version

_all_optimizers = []

if "ORTMODULE_USE_EFFICIENT_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_EFFICIENT_ATTENTION")) == 1:
if (
"ORTMODULE_USE_EFFICIENT_ATTENTION" in os.environ
and int(os.getenv("ORTMODULE_USE_EFFICIENT_ATTENTION")) == 1
and Version(torch.__version__) >= Version("2.1.1")
):
from ._aten_attn import optimize_graph_for_aten_efficient_attention # noqa: F401

_all_optimizers.append("optimize_graph_for_aten_efficient_attention")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,31 +245,25 @@ def _optimize_for_pattern_1(matcher: GraphMatcher, idx: int, nodes: List[NodePro
("MatMul", False, []), # 0
("Mul", True, [(0, 0, 0)]), # 1
("Mul", True, [(0, 0, 1)]), # 2
("Cast", True, [(1, 0, 0)]), # 3
("Cast", True, [(2, 0, 0)]), # 4
("Transpose", True, [(3, 0, 0)]), # 5
("Transpose", True, [(4, 0, 0)]), # 6
("Softmax", False, [(0, 0, 0)]), # 7
("Cast", False, [(7, 0, 0)]), # 8
("MatMul", False, [(8, 0, 0)]), # 9
("Transpose", True, [(9, 0, 1)]), # 10
("Transpose", False, [(9, 0, 0)]), # 11
("FusedMatMul", False, [(10, 0, 1)]), # 12
("Cast", False, [(12, 0, 0)]), # 13
("SoftmaxGrad_13", False, [(13, 0, 0), (7, 0, 1)]), # 14
("FusedMatMul", False, [(2, 0, 1), (14, 0, 0)]), # 15
("FusedMatMul", False, [(1, 0, 0), (14, 0, 1)]), # 16
("Mul", False, [(15, 0, 0)]), # 17
("Mul", False, [(16, 0, 0)]), # 18
("Identity", False, [(17, 0, 0)]), # 19
("Identity", False, [(18, 0, 0)]), # 20
("Cast", False, [(19, 0, 0)]), # 21
("Cast", False, [(20, 0, 0)]), # 22
("Transpose", False, [(21, 0, 0)]), # 23
("Transpose", False, [(22, 0, 0)]), # 24
("FusedMatMul", False, [(8, 0, 0)]), # 25
("Transpose", True, [(25, 0, 1)]), # 26
("Transpose", False, [(25, 0, 0)]), # 27
("Transpose", True, [(1, 0, 0)]), # 3
("Transpose", True, [(2, 0, 0)]), # 4
("Softmax", False, [(0, 0, 0)]), # 5
("MatMul", False, [(5, 0, 0)]), # 6
("Transpose", True, [(6, 0, 1)]), # 7
("Transpose", False, [(6, 0, 0)]), # 8
("FusedMatMul", False, [(7, 0, 1)]), # 9
("SoftmaxGrad_13", False, [(9, 0, 0), (5, 0, 1)]), # 10
("FusedMatMul", False, [(2, 0, 1), (10, 0, 0)]), # 11
("FusedMatMul", False, [(1, 0, 0), (10, 0, 1)]), # 12
("Mul", False, [(11, 0, 0)]), # 13
("Mul", False, [(12, 0, 0)]), # 14
("Identity", False, [(13, 0, 0)]), # 15
("Identity", False, [(14, 0, 0)]), # 16
("Transpose", False, [(15, 0, 0)]), # 17
("Transpose", False, [(16, 0, 0)]), # 18
("FusedMatMul", False, [(5, 0, 0)]), # 19
("Transpose", True, [(19, 0, 1)]), # 20
("Transpose", False, [(19, 0, 0)]), # 21
]


Expand All @@ -280,27 +274,24 @@ def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodePro
scale_value_2 = matcher.get_constant_value(nodes[2].input[1])
scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2
if not (
check_attribute_value(nodes[3], "to", 1)
and check_attribute_value(nodes[4], "to", 1)
and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1])
and check_attribute_value(nodes[8], "to", 10)
and check_attribute_value(nodes[10], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[11], "perm", [0, 2, 1, 3])
check_attribute_value(nodes[3], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[4], "perm", [0, 2, 3, 1])
and check_attribute_value(nodes[7], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3])
and scale_value_1 == scale_value_2
):
return [], [], []

nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
idx,
nodes[5].input[0],
nodes[6].input[0],
nodes[10].input[0],
nodes[11].output[0],
nodes[26].input[0],
nodes[23].output[0],
nodes[24].output[0],
nodes[27].output[0],
nodes[3].input[0],
nodes[4].input[0],
nodes[7].input[0],
nodes[8].output[0],
nodes[20].input[0],
nodes[17].output[0],
nodes[18].output[0],
nodes[21].output[0],
"",
False,
scale_value_1,
Expand All @@ -315,39 +306,32 @@ def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodePro
("MatMul", False, []), # 0
("Mul", True, [(0, 0, 0)]), # 1
("Mul", True, [(0, 0, 1)]), # 2
("Cast", True, [(1, 0, 0)]), # 3
("Cast", True, [(2, 0, 0)]), # 4
("Transpose", True, [(3, 0, 0)]), # 5
("Transpose", True, [(4, 0, 0)]), # 6
("Add", False, [(0, 0, 0)]), # 7
("Cast", True, [(7, 0, 1)]), # 8
("Slice", True, [(8, 0, 0)]), # 9
("Slice", True, [(9, 0, 0)]), # 10
("Unsqueeze", True, [(9, 0, 2)]), # 11
("Gather", True, [(11, 0, 0)]), # 12
("Shape", True, [(12, 0, 0)]), # 13
("Softmax", False, [(7, 0, 0)]), # 14
("Cast", False, [(14, 0, 0)]), # 15
("MatMul", False, [(15, 0, 0)]), # 16
("Transpose", True, [(16, 0, 1)]), # 17
("Transpose", False, [(16, 0, 0)]), # 18
("FusedMatMul", False, [(17, 0, 1)]), # 19
("Cast", False, [(19, 0, 0)]), # 20
("SoftmaxGrad_13", False, [(20, 0, 0), (14, 0, 1)]), # 21
("Identity", False, [(21, 0, 0)]), # 22
("FusedMatMul", False, [(2, 0, 1), (22, 0, 0)]), # 23
("FusedMatMul", False, [(1, 0, 0), (22, 0, 1)]), # 24
("Mul", False, [(23, 0, 0)]), # 25
("Mul", False, [(24, 0, 0)]), # 26
("Identity", False, [(25, 0, 0)]), # 27
("Identity", False, [(26, 0, 0)]), # 28
("Cast", False, [(27, 0, 0)]), # 29
("Cast", False, [(28, 0, 0)]), # 30
("Transpose", False, [(29, 0, 0)]), # 31
("Transpose", False, [(30, 0, 0)]), # 32
("FusedMatMul", False, [(15, 0, 0)]), # 33
("Transpose", True, [(33, 0, 1)]), # 34
("Transpose", False, [(33, 0, 0)]), # 35
("Transpose", True, [(1, 0, 0)]), # 3
("Transpose", True, [(2, 0, 0)]), # 4
("Add", False, [(0, 0, 0)]), # 5
("Slice", True, [(5, 0, 1)]), # 6
("Slice", True, [(6, 0, 0)]), # 7
("Unsqueeze", True, [(6, 0, 2)]), # 8
("Gather", True, [(8, 0, 0)]), # 9
("Shape", True, [(9, 0, 0)]), # 10
("Softmax", False, [(5, 0, 0)]), # 11
("MatMul", False, [(11, 0, 0)]), # 12
("Transpose", True, [(12, 0, 1)]), # 13
("Transpose", False, [(12, 0, 0)]), # 14
("FusedMatMul", False, [(13, 0, 1)]), # 15
("SoftmaxGrad_13", False, [(15, 0, 0), (11, 0, 1)]), # 16
("Identity", False, [(16, 0, 0)]), # 17
("FusedMatMul", False, [(2, 0, 1), (17, 0, 0)]), # 18
("FusedMatMul", False, [(1, 0, 0), (17, 0, 1)]), # 19
("Mul", False, [(18, 0, 0)]), # 20
("Mul", False, [(19, 0, 0)]), # 21
("Identity", False, [(20, 0, 0)]), # 22
("Identity", False, [(21, 0, 0)]), # 23
("Transpose", False, [(22, 0, 0)]), # 24
("Transpose", False, [(23, 0, 0)]), # 25
("FusedMatMul", False, [(11, 0, 0)]), # 26
("Transpose", True, [(26, 0, 1)]), # 27
("Transpose", False, [(26, 0, 0)]), # 28
]


Expand All @@ -358,27 +342,24 @@ def _optimize_for_pattern_3(matcher: GraphMatcher, idx: int, nodes: List[NodePro
scale_value_2 = matcher.get_constant_value(nodes[2].input[1])
scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2
if not (
check_attribute_value(nodes[3], "to", 1)
and check_attribute_value(nodes[4], "to", 1)
and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1])
and check_attribute_value(nodes[15], "to", 10)
and check_attribute_value(nodes[17], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[18], "perm", [0, 2, 1, 3])
check_attribute_value(nodes[3], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[4], "perm", [0, 2, 3, 1])
and check_attribute_value(nodes[13], "perm", [0, 2, 1, 3])
and check_attribute_value(nodes[14], "perm", [0, 2, 1, 3])
and scale_value_1 == scale_value_2
):
return [], [], []

nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
idx,
nodes[5].input[0],
nodes[6].input[0],
nodes[17].input[0],
nodes[18].output[0],
nodes[34].input[0],
nodes[31].output[0],
nodes[32].output[0],
nodes[35].output[0],
nodes[3].input[0],
nodes[4].input[0],
nodes[13].input[0],
nodes[14].output[0],
nodes[27].input[0],
nodes[24].output[0],
nodes[25].output[0],
nodes[28].output[0],
"",
False,
scale_value_1,
Expand Down

0 comments on commit 3bc9efc

Please sign in to comment.