add new pattern for eff attn
centwang committed Oct 23, 2023
1 parent 1194203 commit 313a5c6
Showing 1 changed file with 159 additions and 3 deletions.
@@ -34,10 +34,12 @@ def _make_efficient_attention_nodes(
     expand_bias: bool,
     scale: float,
     dropout_ratio: float,
+    causal: bool,
 ):
     nodes_to_add = []
     scale_node = make_constant_node("scale_" + str(idx), TensorProto.FLOAT, [], [scale])
     dropout_ratio_node = make_constant_node("dropout_ratio_" + str(idx), TensorProto.FLOAT, [], [dropout_ratio])
+    causal_node = make_constant_node("causal_" + str(idx), TensorProto.INT64, [], [1 if causal else 0])
     int_zero_node = make_constant_node("int_zero_" + str(idx), TensorProto.INT64, [], [0])
     true_node = make_constant_node("true_" + str(idx), TensorProto.BOOL, [], [True])
     false_node = make_constant_node("false_" + str(idx), TensorProto.BOOL, [], [False])
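
For context: make_constant_node is a helper defined earlier in this file. A minimal sketch of an equivalent, assuming the same (name, data_type, dims, vals) argument order and using only stock onnx.helper calls:

from onnx import TensorProto, helper

def make_constant_node(name, data_type, dims, vals):
    # Constant node whose single output carries the given tensor value.
    return helper.make_node(
        "Constant",
        inputs=[],
        outputs=[name],
        value=helper.make_tensor(name + "_value", data_type, dims, vals),
    )

# The new causal flag above is such a constant: a scalar INT64 holding 0 or 1.
causal_node = make_constant_node("causal_0", TensorProto.INT64, [], [1])
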
@@ -70,7 +72,7 @@ def _make_efficient_attention_nodes(
"",
"",
dropout_ratio_node.output[0],
int_zero_node.output[0],
causal_node.output[0],
true_node.output[0],
scale_node.output[0],
"",
@@ -99,7 +101,7 @@ def _make_efficient_attention_nodes(
             dropout_ratio_node.output[0],
             seed.name,
             offset.name,
-            int_zero_node.output[0],
+            causal_node.output[0],
             false_node.output[0],
             scale_node.output[0],
             "",
@@ -110,7 +112,9 @@ def _make_efficient_attention_nodes(
"org.pytorch.aten",
operator="_efficient_attention_backward",
)
nodes_to_add.extend([scale_node, dropout_ratio_node, int_zero_node, true_node, false_node, fwd_node, bwd_node])
nodes_to_add.extend(
[scale_node, dropout_ratio_node, causal_node, int_zero_node, true_node, false_node, fwd_node, bwd_node]
)
return nodes_to_add, new_value_infos

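The two ATen nodes built here offload to PyTorch's memory-efficient attention kernels (the backward operator name "_efficient_attention_backward" is visible above). A reference sketch of the forward computation, assuming a [batch, seq, heads, head_dim] input layout and a plain lower-triangular reading of the new causal flag:

import torch

def efficient_attention_reference(q, k, v, scale: float, causal: bool):
    qt, kt, vt = (t.permute(0, 2, 1, 3) for t in (q, k, v))  # -> [B, H, S, D]
    scores = torch.matmul(qt, kt.transpose(-1, -2)) * scale
    if causal:  # causal_node == 1: mask out future positions
        s = scores.size(-1)
        future = torch.ones(s, s, dtype=torch.bool, device=scores.device).triu(1)
        scores = scores.masked_fill(future, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, vt).permute(0, 2, 1, 3)  # back to [B, S, H, D]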

@@ -172,6 +176,7 @@ def _optimize_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
         add_input_shape_0 != add_input_shape_1,
         1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value),
         ratio_value,
+        False,
     )
     return nodes, nodes_to_add, new_value_infos
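
Note the inversion above: _PATTERN_0 evidently matches a graph that divides the raw attention scores by a constant (typically sqrt(head_size)), while the ATen op takes a multiplicative scale, hence the 1 / float(...). Patterns 2 and 3 below match Mul nodes instead and pass the constant through unchanged. A quick arithmetic check under the sqrt(head_size) assumption:

import math

head_size = 64
div_constant = math.sqrt(head_size)  # constant carried by the matched divide-style node
scale = 1 / div_constant             # value handed to _make_efficient_attention_nodes
assert abs(scale - 0.125) < 1e-12
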

@@ -230,13 +235,164 @@ def _optimize_for_pattern_1(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
         add_input_shape_0 != add_input_shape_1,
         1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value),
         0.0,
+        False,
     )
     return nodes, nodes_to_add, new_value_infos


+# No causal mask, no attention mask, without Dropout.
+_PATTERN_2: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
+    ("MatMul", False, []),  # 0
+    ("Mul", True, [(0, 0, 0)]),  # 1
+    ("Mul", True, [(0, 0, 1)]),  # 2
+    ("Cast", True, [(1, 0, 0)]),  # 3
+    ("Cast", True, [(2, 0, 0)]),  # 4
+    ("Transpose", True, [(3, 0, 0)]),  # 5
+    ("Transpose", True, [(4, 0, 0)]),  # 6
+    ("Softmax", False, [(0, 0, 0)]),  # 7
+    ("Cast", False, [(7, 0, 0)]),  # 8
+    ("MatMul", False, [(8, 0, 0)]),  # 9
+    ("Transpose", True, [(9, 0, 1)]),  # 10
+    ("Transpose", False, [(9, 0, 0)]),  # 11
+    ("FusedMatMul", False, [(10, 0, 1)]),  # 12
+    ("Cast", False, [(12, 0, 0)]),  # 13
+    ("SoftmaxGrad_13", False, [(13, 0, 0), (7, 0, 1)]),  # 14
+    ("FusedMatMul", False, [(2, 0, 1), (14, 0, 0)]),  # 15
+    ("FusedMatMul", False, [(1, 0, 0), (14, 0, 1)]),  # 16
+    ("Mul", False, [(15, 0, 0)]),  # 17
+    ("Mul", False, [(16, 0, 0)]),  # 18
+    ("Identity", False, [(17, 0, 0)]),  # 19
+    ("Identity", False, [(18, 0, 0)]),  # 20
+    ("Cast", False, [(19, 0, 0)]),  # 21
+    ("Cast", False, [(20, 0, 0)]),  # 22
+    ("Transpose", False, [(21, 0, 0)]),  # 23
+    ("Transpose", False, [(22, 0, 0)]),  # 24
+    ("FusedMatMul", False, [(8, 0, 0)]),  # 25
+    ("Transpose", True, [(25, 0, 1)]),  # 26
+    ("Transpose", False, [(25, 0, 0)]),  # 27
+]
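
How to read these rows (an inferred reading, not documented in this diff): each entry appears to be (op_type, is_producer, edges), where an edge (i, out_idx, in_idx) links this node to the node at pattern row i; True rows feed the referenced row's input, False rows consume its output. A toy decoding of the head of _PATTERN_2, names illustrative only:

pattern_head = [
    ("MatMul", False, []),  # row 0: anchor node
    ("Mul", True, [(0, 0, 0)]),  # its output 0 feeds input 0 of row 0
    ("Mul", True, [(0, 0, 1)]),  # its output 0 feeds input 1 of row 0
    ("Softmax", False, [(0, 0, 0)]),  # consumes output 0 of row 0 as input 0
]
for op, is_producer, edges in pattern_head:
    for i, out_idx, in_idx in edges:
        if is_producer:
            print(f"{op}.output[{out_idx}] -> pattern[{i}].input[{in_idx}]")
        else:
            print(f"pattern[{i}].output[{out_idx}] -> {op}.input[{in_idx}]")
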


+def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
+    # Check forward only as the backward is expected to be consistent if it's built correctly.
+    scale_value_1 = matcher.get_constant_value(nodes[1].input[1])
+    scale_value_1 = scale_value_1[0] if isinstance(scale_value_1, list) else scale_value_1
+    scale_value_2 = matcher.get_constant_value(nodes[2].input[1])
+    scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2
+    if not (
+        check_attribute_value(nodes[3], "to", 1)
+        and check_attribute_value(nodes[4], "to", 1)
+        and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3])
+        and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1])
+        and check_attribute_value(nodes[8], "to", 10)
+        and check_attribute_value(nodes[10], "perm", [0, 2, 1, 3])
+        and check_attribute_value(nodes[11], "perm", [0, 2, 1, 3])
+        and scale_value_1 == scale_value_2
+    ):
+        return [], [], []
+
+    nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
+        idx,
+        nodes[5].input[0],
+        nodes[6].input[0],
+        nodes[10].input[0],
+        nodes[11].output[0],
+        nodes[26].input[0],
+        nodes[23].output[0],
+        nodes[24].output[0],
+        nodes[27].output[0],
+        "",
+        False,
+        scale_value_1,
+        0.0,
+        False,
+    )
+    return nodes, nodes_to_add, new_value_infos
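
The magic numbers in the "to" checks above are onnx.TensorProto dtype enums, so the pattern requires fp32 casts on the Q/K side and an fp16 cast after the Softmax:

from onnx import TensorProto

assert TensorProto.FLOAT == 1  # the Casts feeding the Q/K Mul nodes (rows 3 and 4)
assert TensorProto.FLOAT16 == 10  # the Cast after Softmax (row 8)
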


+# Has causal mask, no attention mask, without Dropout.
+_PATTERN_3: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
+    ("MatMul", False, []),  # 0
+    ("Mul", True, [(0, 0, 0)]),  # 1
+    ("Mul", True, [(0, 0, 1)]),  # 2
+    ("Cast", True, [(1, 0, 0)]),  # 3
+    ("Cast", True, [(2, 0, 0)]),  # 4
+    ("Transpose", True, [(3, 0, 0)]),  # 5
+    ("Transpose", True, [(4, 0, 0)]),  # 6
+    ("Add", False, [(0, 0, 0)]),  # 7
+    ("Cast", True, [(7, 0, 1)]),  # 8
+    ("Slice", True, [(8, 0, 0)]),  # 9
+    ("Slice", True, [(9, 0, 0)]),  # 10
+    ("Unsqueeze", True, [(9, 0, 2)]),  # 11
+    ("Gather", True, [(11, 0, 0)]),  # 12
+    ("Shape", True, [(12, 0, 0)]),  # 13
+    ("Softmax", False, [(7, 0, 0)]),  # 14
+    ("Cast", False, [(14, 0, 0)]),  # 15
+    ("MatMul", False, [(15, 0, 0)]),  # 16
+    ("Transpose", True, [(16, 0, 1)]),  # 17
+    ("Transpose", False, [(16, 0, 0)]),  # 18
+    ("FusedMatMul", False, [(17, 0, 1)]),  # 19
+    ("Cast", False, [(19, 0, 0)]),  # 20
+    ("SoftmaxGrad_13", False, [(20, 0, 0), (14, 0, 1)]),  # 21
+    ("Identity", False, [(21, 0, 0)]),  # 22
+    ("FusedMatMul", False, [(2, 0, 1), (22, 0, 0)]),  # 23
+    ("FusedMatMul", False, [(1, 0, 0), (22, 0, 1)]),  # 24
+    ("Mul", False, [(23, 0, 0)]),  # 25
+    ("Mul", False, [(24, 0, 0)]),  # 26
+    ("Identity", False, [(25, 0, 0)]),  # 27
+    ("Identity", False, [(26, 0, 0)]),  # 28
+    ("Cast", False, [(27, 0, 0)]),  # 29
+    ("Cast", False, [(28, 0, 0)]),  # 30
+    ("Transpose", False, [(29, 0, 0)]),  # 31
+    ("Transpose", False, [(30, 0, 0)]),  # 32
+    ("FusedMatMul", False, [(15, 0, 0)]),  # 33
+    ("Transpose", True, [(33, 0, 1)]),  # 34
+    ("Transpose", False, [(33, 0, 0)]),  # 35
+]


+def _optimize_for_pattern_3(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
+    # Check forward only as the backward is expected to be consistent if it's built correctly.
+    scale_value_1 = matcher.get_constant_value(nodes[1].input[1])
+    scale_value_1 = scale_value_1[0] if isinstance(scale_value_1, list) else scale_value_1
+    scale_value_2 = matcher.get_constant_value(nodes[2].input[1])
+    scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2
+    if not (
+        check_attribute_value(nodes[3], "to", 1)
+        and check_attribute_value(nodes[4], "to", 1)
+        and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3])
+        and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1])
+        and check_attribute_value(nodes[15], "to", 10)
+        and check_attribute_value(nodes[17], "perm", [0, 2, 1, 3])
+        and check_attribute_value(nodes[18], "perm", [0, 2, 1, 3])
+        and scale_value_1 == scale_value_2
+    ):
+        return [], [], []
+
+    nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
+        idx,
+        nodes[5].input[0],
+        nodes[6].input[0],
+        nodes[17].input[0],
+        nodes[18].output[0],
+        nodes[34].input[0],
+        nodes[31].output[0],
+        nodes[32].output[0],
+        nodes[35].output[0],
+        "",
+        False,
+        scale_value_1,
+        0.0,
+        True,
+    )
+    return nodes, nodes_to_add, new_value_infos

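_PATTERN_3 differs from _PATTERN_2 mainly in rows 7 through 13: the Softmax input is an Add whose second operand arrives through a Cast/Slice/Slice/Unsqueeze/Gather/Shape chain, i.e. a precomputed causal bias sliced to the current sequence length, which is why this pattern sets causal=True and passes no bias input. A sketch of the kind of construction that chain likely corresponds to (an assumption, not taken from this commit):

import torch

max_len = 1024  # hypothetical model maximum sequence length
# large negative values above the diagonal, zeros elsewhere
causal_bias = torch.full((1, 1, max_len, max_len), torch.finfo(torch.float32).min).triu(1)

def biased_scores(scores):  # scores: [B, H, S, S]
    s = scores.size(-1)
    return scores + causal_bias[:, :, :s, :s]  # the Add at pattern row 7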

 _PATTERNS = [
     (_PATTERN_0, _optimize_for_pattern_0),
     (_PATTERN_1, _optimize_for_pattern_1),
+    (_PATTERN_2, _optimize_for_pattern_2),
+    (_PATTERN_3, _optimize_for_pattern_3),
 ]


