diff --git a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc index d42af92c7c66d..1f65d886a4b8b 100644 --- a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc +++ b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc @@ -224,8 +224,10 @@ void IterateSubgraphFromNode(Graph& graph, visited.insert(cur); if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Add", {7, 13, 14}) || graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "BiasGelu", {1}, kMSDomain) || - graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Sub", {7, 13, 14}) || - graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Mul", {7, 13, 14})) { + graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Div", {7, 13, 14}) || + graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Mul", {7, 13, 14}) || + graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Pow", {7, 12, 13, 15}) || + graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Sub", {7, 13, 14})) { ORT_ENFORCE(subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end() || subgraph.find(cur->MutableInputDefs()[1]) != subgraph.end()); if (cur->InputDefs()[0]->Shape() && cur->InputDefs()[1]->Shape()) { @@ -278,7 +280,10 @@ void IterateSubgraphFromNode(Graph& graph, subgraph.insert(cur->MutableOutputDefs()[1]); PushAllOutputNode(graph, to_visit, cur, visited); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Cast", {9, 13}) || - graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Gelu", {1}, kMSDomain)) { + graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "FastGelu", {1}, kMSDomain) || + graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Gelu", {1}, kMSDomain) || + graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "QuickGelu", {1}, kMSDomain) || + graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Sqrt", {6, 13})) { ORT_ENFORCE(subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end()); subgraph.insert(cur->MutableOutputDefs()[0]); PushAllOutputNode(graph, to_visit, cur, visited); diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index da217eb76949c..5078058995281 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -5725,8 +5725,6 @@ def run_step(model, input, target): @pytest.mark.parametrize("label_is_sparse", [False, True]) @pytest.mark.parametrize("rank", [1, 2]) def test_runtime_inspector_label_and_embed_sparsity_detection(embed_is_sparse, label_is_sparse, rank, caplog): - os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] = "1" - class NeuralNetCrossEntropyLoss(torch.nn.Module): def __init__(self, num_embeddings, embedding_dim): super().__init__() @@ -5797,10 +5795,12 @@ def run_step(model, input, positions): "test_cases", [ ("Add", 0), + ("Add", 1), ("Add", 2), ("Add", 3), ("Add", 4), ("Sub", 0), + ("Sub", 1), ("Sub", 2), ("Sub", 3), ("Sub", 4), @@ -5808,12 +5808,22 @@ def run_step(model, input, positions): ("Mul", 2), ("Mul", 3), ("Mul", 4), + ("Div", 0), + ("Div", 2), + ("Div", 3), + ("Div", 4), + ("Pow", 0), + ("Pow", 1), + ("Pow", 2), + ("Pow", 3), + ("Pow", 4), ("MatMul", 0), ("MatMul", 1), ("Dropout", 0), ("LayerNormalization", 0), ("LayerNormalization", 1), ("Cast", 0), + ("Sqrt", 0), ("BiasGelu", 0), ("Gelu", 0), ("ReduceMean", 0), @@ -5821,7 +5831,6 @@ def run_step(model, input, positions): ], ) def test_ops_for_padding_elimination(test_cases): - os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] = "1" test_op = test_cases[0] case = test_cases[1] @@ -5848,7 +5857,7 @@ def __init__(self, vocab_size, hidden_size, pad_token_id): # pattern should be insert to the arg of [batch_size, 1, hidden_size]. # in case 3, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [1, hidden_size], # the test_op should be included in padding elimination subgraph and a 'Expand + FlattenAndUnpad' - # pattern should be insert to the arg of [batch_size, 1, hidden_size]. + # pattern should be insert to the arg of [1, hidden_size]. # in case 4, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, seqlen, hidden_size], # the test_op should be included in padding elimination subgraph and the PadAndUnflatten should be added to # output of test_op. Besides, the other input of Add should be added 'FlattenAndUnpad' to @@ -5858,6 +5867,8 @@ def test_elementwise(self, input_ids): one_input = None if case == 0: one_input = torch.ones(self.hidden_size, dtype=torch.long).to(device) + elif case == 1: + one_input = 1 elif case == 2: one_input = torch.ones((input_shape[0], 1, self.hidden_size), dtype=torch.long).to(device) elif case == 3: @@ -5872,6 +5883,10 @@ def test_elementwise(self, input_ids): output = one_input - inputs_embeds elif test_op == "Mul": output = one_input * inputs_embeds + elif test_op == "Div": + output = inputs_embeds / one_input + elif test_op == "Pow": + output = inputs_embeds ** (one_input * 2) else: output = None return output @@ -5911,6 +5926,8 @@ def test_other(self, input_ids): output = torch.nn.functional.gelu(inputs_embeds + bias) elif test_op == "Gelu": output = torch.nn.functional.gelu(inputs_embeds) + elif test_op == "Sqrt": + output = torch.sqrt(inputs_embeds) elif test_op == "ReduceMean": # In case 0, the inputs_embeds are reduced at last dimension, the ReduceMean should be included in padding # elimination subgraph and the PadAndUnflatten should be added to output of ReduceMean. @@ -5924,7 +5941,7 @@ def test_other(self, input_ids): return output def forward(self, input_ids): - if test_op in ["Add", "Mul", "Sub"]: + if test_op in ["Add", "Mul", "Sub", "Div", "Pow"]: output = self.test_elementwise(input_ids) elif test_op == "MatMul": output = self.test_matmul(input_ids) @@ -5953,7 +5970,7 @@ def generate_inputs(batch_size, max_seq_length, vocab_size): model(x) training_model = model._torch_module._execution_manager(True)._onnx_models.optimized_model - if test_op == "Sub": + if test_op == "Sub" or test_op == "Pow": assert len([node.op_type for node in training_model.graph.node if node.op_type == "Sub"]) == 2 else: assert len([node.op_type for node in training_model.graph.node if node.op_type == "Sub"]) == 1 @@ -5974,7 +5991,7 @@ def find_input_node_type(model, arg): return result[0].op_type if len(result) == 1 else None recover_pad_input_optypes = [find_input_node_type(training_model, arg) for arg in recover_pad_node.input] - if test_op == "Add" or test_op == "Mul" or test_op == "Sub": + if test_op == "Add" or test_op == "Mul" or test_op == "Sub" or test_op == "Div" or test_op == "Pow": assert test_op in recover_pad_input_optypes else: if case == 0: @@ -5982,11 +5999,8 @@ def find_input_node_type(model, arg): else: assert "ATen" in recover_pad_input_optypes - del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] - def test_e2e_padding_elimination(): - os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] = "1" seed = 5033 random.seed(seed) np.random.seed(seed) @@ -6129,7 +6143,6 @@ def generate_inputs(batch_size, max_seq_length, vocab_size): training_model = ort_model._torch_module._execution_manager(True)._onnx_models.optimized_model assert "FlattenAndUnpad" in [node.op_type for node in training_model.graph.node] assert "PadAndUnflatten" in [node.op_type for node in training_model.graph.node] - del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] @pytest.mark.skipif(