diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py
index 843f0d32dd65f..94bd41293b427 100644
--- a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py
+++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py
@@ -34,10 +34,12 @@ def _make_efficient_attention_nodes(
     expand_bias: bool,
     scale: float,
     dropout_ratio: float,
+    causal: bool,
 ):
     nodes_to_add = []
     scale_node = make_constant_node("scale_" + str(idx), TensorProto.FLOAT, [], [scale])
     dropout_ratio_node = make_constant_node("dropout_ratio_" + str(idx), TensorProto.FLOAT, [], [dropout_ratio])
+    causal_node = make_constant_node("causal_" + str(idx), TensorProto.INT64, [], [1 if causal else 0])
     int_zero_node = make_constant_node("int_zero_" + str(idx), TensorProto.INT64, [], [0])
     true_node = make_constant_node("true_" + str(idx), TensorProto.BOOL, [], [True])
     false_node = make_constant_node("false_" + str(idx), TensorProto.BOOL, [], [False])
@@ -70,7 +72,7 @@ def _make_efficient_attention_nodes(
             "",
             "",
             dropout_ratio_node.output[0],
-            int_zero_node.output[0],
+            causal_node.output[0],
             true_node.output[0],
             scale_node.output[0],
             "",
@@ -99,7 +101,7 @@ def _make_efficient_attention_nodes(
             dropout_ratio_node.output[0],
             seed.name,
             offset.name,
-            int_zero_node.output[0],
+            causal_node.output[0],
             false_node.output[0],
             scale_node.output[0],
             "",
@@ -110,7 +112,9 @@ def _make_efficient_attention_nodes(
         "org.pytorch.aten",
         operator="_efficient_attention_backward",
     )
-    nodes_to_add.extend([scale_node, dropout_ratio_node, int_zero_node, true_node, false_node, fwd_node, bwd_node])
+    nodes_to_add.extend(
+        [scale_node, dropout_ratio_node, causal_node, int_zero_node, true_node, false_node, fwd_node, bwd_node]
+    )
     return nodes_to_add, new_value_infos
 
 
@@ -172,6 +176,7 @@ def _optimize_for_pattern_0(matcher: GraphMatcher, idx: int, nodes: List[NodePro
         add_input_shape_0 != add_input_shape_1,
         1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value),
         ratio_value,
+        False,
     )
     return nodes, nodes_to_add, new_value_infos
 
@@ -230,6 +235,155 @@ def _optimize_for_pattern_1(matcher: GraphMatcher, idx: int, nodes: List[NodePro
         add_input_shape_0 != add_input_shape_1,
         1 / float(scale_value[0] if isinstance(scale_value, list) else scale_value),
         0.0,
+        False,
+    )
+    return nodes, nodes_to_add, new_value_infos
+
+
+# No causal mask, no attention mask, without Dropout.
+_PATTERN_2: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
+    ("MatMul", False, []),  # 0
+    ("Mul", True, [(0, 0, 0)]),  # 1
+    ("Mul", True, [(0, 0, 1)]),  # 2
+    ("Cast", True, [(1, 0, 0)]),  # 3
+    ("Cast", True, [(2, 0, 0)]),  # 4
+    ("Transpose", True, [(3, 0, 0)]),  # 5
+    ("Transpose", True, [(4, 0, 0)]),  # 6
+    ("Softmax", False, [(0, 0, 0)]),  # 7
+    ("Cast", False, [(7, 0, 0)]),  # 8
+    ("MatMul", False, [(8, 0, 0)]),  # 9
+    ("Transpose", True, [(9, 0, 1)]),  # 10
+    ("Transpose", False, [(9, 0, 0)]),  # 11
+    ("FusedMatMul", False, [(10, 0, 1)]),  # 12
+    ("Cast", False, [(12, 0, 0)]),  # 13
+    ("SoftmaxGrad_13", False, [(13, 0, 0), (7, 0, 1)]),  # 14
+    ("FusedMatMul", False, [(2, 0, 1), (14, 0, 0)]),  # 15
+    ("FusedMatMul", False, [(1, 0, 0), (14, 0, 1)]),  # 16
+    ("Mul", False, [(15, 0, 0)]),  # 17
+    ("Mul", False, [(16, 0, 0)]),  # 18
+    ("Identity", False, [(17, 0, 0)]),  # 19
+    ("Identity", False, [(18, 0, 0)]),  # 20
+    ("Cast", False, [(19, 0, 0)]),  # 21
+    ("Cast", False, [(20, 0, 0)]),  # 22
+    ("Transpose", False, [(21, 0, 0)]),  # 23
+    ("Transpose", False, [(22, 0, 0)]),  # 24
+    ("FusedMatMul", False, [(8, 0, 0)]),  # 25
+    ("Transpose", True, [(25, 0, 1)]),  # 26
+    ("Transpose", False, [(25, 0, 0)]),  # 27
+]
+
+
+def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
+    # Check forward only as the backward is expected to be consistent if it's built correctly.
+    scale_value_1 = matcher.get_constant_value(nodes[1].input[1])
+    scale_value_1 = scale_value_1[0] if isinstance(scale_value_1, list) else scale_value_1
+    scale_value_2 = matcher.get_constant_value(nodes[2].input[1])
+    scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2
+    if not (
+        check_attribute_value(nodes[3], "to", 1)
+        and check_attribute_value(nodes[4], "to", 1)
+        and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3])
+        and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1])
+        and check_attribute_value(nodes[8], "to", 10)
+        and check_attribute_value(nodes[10], "perm", [0, 2, 1, 3])
+        and check_attribute_value(nodes[11], "perm", [0, 2, 1, 3])
+        and scale_value_1 == scale_value_2
+    ):
+        return [], [], []
+
+    nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
+        idx,
+        nodes[5].input[0],
+        nodes[6].input[0],
+        nodes[10].input[0],
+        nodes[11].output[0],
+        nodes[26].input[0],
+        nodes[23].output[0],
+        nodes[24].output[0],
+        nodes[27].output[0],
+        "",
+        False,
+        scale_value_1,
+        0.0,
+        False,
+    )
+    return nodes, nodes_to_add, new_value_infos
+
+
+# Has causal mask, no attention mask, without Dropout.
+_PATTERN_3: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [
+    ("MatMul", False, []),  # 0
+    ("Mul", True, [(0, 0, 0)]),  # 1
+    ("Mul", True, [(0, 0, 1)]),  # 2
+    ("Cast", True, [(1, 0, 0)]),  # 3
+    ("Cast", True, [(2, 0, 0)]),  # 4
+    ("Transpose", True, [(3, 0, 0)]),  # 5
+    ("Transpose", True, [(4, 0, 0)]),  # 6
+    ("Add", False, [(0, 0, 0)]),  # 7
+    ("Cast", True, [(7, 0, 1)]),  # 8
+    ("Slice", True, [(8, 0, 0)]),  # 9
+    ("Slice", True, [(9, 0, 0)]),  # 10
+    ("Unsqueeze", True, [(9, 0, 2)]),  # 11
+    ("Gather", True, [(11, 0, 0)]),  # 12
+    ("Shape", True, [(12, 0, 0)]),  # 13
+    ("Softmax", False, [(7, 0, 0)]),  # 14
+    ("Cast", False, [(14, 0, 0)]),  # 15
+    ("MatMul", False, [(15, 0, 0)]),  # 16
+    ("Transpose", True, [(16, 0, 1)]),  # 17
+    ("Transpose", False, [(16, 0, 0)]),  # 18
+    ("FusedMatMul", False, [(17, 0, 1)]),  # 19
+    ("Cast", False, [(19, 0, 0)]),  # 20
+    ("SoftmaxGrad_13", False, [(20, 0, 0), (14, 0, 1)]),  # 21
+    ("Identity", False, [(21, 0, 0)]),  # 22
+    ("FusedMatMul", False, [(2, 0, 1), (22, 0, 0)]),  # 23
+    ("FusedMatMul", False, [(1, 0, 0), (22, 0, 1)]),  # 24
+    ("Mul", False, [(23, 0, 0)]),  # 25
+    ("Mul", False, [(24, 0, 0)]),  # 26
+    ("Identity", False, [(25, 0, 0)]),  # 27
+    ("Identity", False, [(26, 0, 0)]),  # 28
+    ("Cast", False, [(27, 0, 0)]),  # 29
+    ("Cast", False, [(28, 0, 0)]),  # 30
+    ("Transpose", False, [(29, 0, 0)]),  # 31
+    ("Transpose", False, [(30, 0, 0)]),  # 32
+    ("FusedMatMul", False, [(15, 0, 0)]),  # 33
+    ("Transpose", True, [(33, 0, 1)]),  # 34
+    ("Transpose", False, [(33, 0, 0)]),  # 35
+]
+
+
+def _optimize_for_pattern_3(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]):
+    # Check forward only as the backward is expected to be consistent if it's built correctly.
+    scale_value_1 = matcher.get_constant_value(nodes[1].input[1])
+    scale_value_1 = scale_value_1[0] if isinstance(scale_value_1, list) else scale_value_1
+    scale_value_2 = matcher.get_constant_value(nodes[2].input[1])
+    scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2
+    if not (
+        check_attribute_value(nodes[3], "to", 1)
+        and check_attribute_value(nodes[4], "to", 1)
+        and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3])
+        and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1])
+        and check_attribute_value(nodes[15], "to", 10)
+        and check_attribute_value(nodes[17], "perm", [0, 2, 1, 3])
+        and check_attribute_value(nodes[18], "perm", [0, 2, 1, 3])
+        and scale_value_1 == scale_value_2
+    ):
+        return [], [], []
+
+    nodes_to_add, new_value_infos = _make_efficient_attention_nodes(
+        idx,
+        nodes[5].input[0],
+        nodes[6].input[0],
+        nodes[17].input[0],
+        nodes[18].output[0],
+        nodes[34].input[0],
+        nodes[31].output[0],
+        nodes[32].output[0],
+        nodes[35].output[0],
+        "",
+        False,
+        scale_value_1,
+        0.0,
+        True,
     )
     return nodes, nodes_to_add, new_value_infos
 
@@ -237,6 +391,8 @@ def _optimize_for_pattern_1(matcher: GraphMatcher, idx: int, nodes: List[NodePro
 _PATTERNS = [
     (_PATTERN_0, _optimize_for_pattern_0),
     (_PATTERN_1, _optimize_for_pattern_1),
+    (_PATTERN_2, _optimize_for_pattern_2),
+    (_PATTERN_3, _optimize_for_pattern_3),
 ]
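
Note (not part of the diff): the core of this change is the new `causal` input to `_make_efficient_attention_nodes`. Subgraphs that materialize an explicit additive triangular mask (the Add/Cast/Slice chain matched by `_PATTERN_3`) are rewritten to let the ATen efficient-attention kernels apply causal masking internally instead of consuming a mask tensor. The sketch below illustrates that equivalence; it assumes PyTorch >= 2.1 (for the `scale` keyword) and uses torch.nn.functional.scaled_dot_product_attention as a stand-in for the fused op, with illustrative (batch, heads, seq, head_dim) shapes that are not taken from the PR.

import torch
import torch.nn.functional as F


def manual_causal_attention(q, k, v, scale):
    # Mirrors the subgraph _PATTERN_3 matches, up to the Cast/Transpose
    # plumbing: scaled Q @ K^T, an additive mask that is 0 on and below the
    # diagonal and -inf above it, Softmax, then @ V.
    seq_len = q.shape[-2]
    mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
    attn = torch.softmax((q @ k.transpose(-2, -1)) * scale + mask, dim=-1)
    return attn @ v


q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
scale = q.shape[-1] ** -0.5
# The fused path (is_causal=True, no materialized mask) should agree
# numerically with the explicit-mask formulation above.
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale)
torch.testing.assert_close(
    fused, manual_causal_attention(q, k, v, scale), rtol=1e-4, atol=1e-4
)

The guards in `_optimize_for_pattern_3` exist to keep this rewrite safe: the match is rejected unless both Mul nodes carry the same constant scale and the Transpose nodes use the standard [0, 2, 1, 3] / [0, 2, 3, 1] layout permutations, so only subgraphs with the plain causal-attention semantics shown above are fused.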