Skip to content

Commit

Permalink
support Pow/Div/Sqrt in PaddingElimination
Browse files Browse the repository at this point in the history
  • Loading branch information
guyang3532 committed Mar 27, 2024
1 parent 28907d8 commit 1341182
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,10 @@ void IterateSubgraphFromNode(Graph& graph,
visited.insert(cur);
if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Add", {7, 13, 14}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "BiasGelu", {1}, kMSDomain) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Sub", {7, 13, 14}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Mul", {7, 13, 14})) {
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Div", {7, 13, 14}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Mul", {7, 13, 14}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Pow", {7, 12, 13, 15}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Sub", {7, 13, 14})) {
ORT_ENFORCE(subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end() ||
subgraph.find(cur->MutableInputDefs()[1]) != subgraph.end());
if (cur->InputDefs()[0]->Shape() && cur->InputDefs()[1]->Shape()) {
Expand Down Expand Up @@ -278,7 +280,10 @@ void IterateSubgraphFromNode(Graph& graph,
subgraph.insert(cur->MutableOutputDefs()[1]);
PushAllOutputNode(graph, to_visit, cur, visited);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Cast", {9, 13}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Gelu", {1}, kMSDomain)) {
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "FastGelu", {1}, kMSDomain) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Gelu", {1}, kMSDomain) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "QuickGelu", {1}, kMSDomain) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Sqrt", {6, 13})) {
ORT_ENFORCE(subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end());
subgraph.insert(cur->MutableOutputDefs()[0]);
PushAllOutputNode(graph, to_visit, cur, visited);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5725,8 +5725,6 @@ def run_step(model, input, target):
@pytest.mark.parametrize("label_is_sparse", [False, True])
@pytest.mark.parametrize("rank", [1, 2])
def test_runtime_inspector_label_and_embed_sparsity_detection(embed_is_sparse, label_is_sparse, rank, caplog):
os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] = "1"

class NeuralNetCrossEntropyLoss(torch.nn.Module):
def __init__(self, num_embeddings, embedding_dim):
super().__init__()
Expand Down Expand Up @@ -5797,31 +5795,42 @@ def run_step(model, input, positions):
"test_cases",
[
("Add", 0),
("Add", 1),
("Add", 2),
("Add", 3),
("Add", 4),
("Sub", 0),
("Sub", 1),
("Sub", 2),
("Sub", 3),
("Sub", 4),
("Mul", 0),
("Mul", 2),
("Mul", 3),
("Mul", 4),
("Div", 0),
("Div", 2),
("Div", 3),
("Div", 4),
("Pow", 0),
("Pow", 1),
("Pow", 2),
("Pow", 3),
("Pow", 4),
("MatMul", 0),
("MatMul", 1),
("Dropout", 0),
("LayerNormalization", 0),
("LayerNormalization", 1),
("Cast", 0),
("Sqrt", 0),
("BiasGelu", 0),
("Gelu", 0),
("ReduceMean", 0),
("ReduceMean", 1),
],
)
def test_ops_for_padding_elimination(test_cases):
os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] = "1"
test_op = test_cases[0]
case = test_cases[1]

Expand All @@ -5848,7 +5857,7 @@ def __init__(self, vocab_size, hidden_size, pad_token_id):
# pattern should be insert to the arg of [batch_size, 1, hidden_size].
# in case 3, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [1, hidden_size],
# the test_op should be included in padding elimination subgraph and a 'Expand + FlattenAndUnpad'
# pattern should be insert to the arg of [batch_size, 1, hidden_size].
# pattern should be insert to the arg of [1, hidden_size].
# in case 4, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, seqlen, hidden_size],
# the test_op should be included in padding elimination subgraph and the PadAndUnflatten should be added to
# output of test_op. Besides, the other input of Add should be added 'FlattenAndUnpad' to
Expand All @@ -5858,6 +5867,8 @@ def test_elementwise(self, input_ids):
one_input = None
if case == 0:
one_input = torch.ones(self.hidden_size, dtype=torch.long).to(device)
elif case == 1:
one_input = 1
elif case == 2:
one_input = torch.ones((input_shape[0], 1, self.hidden_size), dtype=torch.long).to(device)
elif case == 3:
Expand All @@ -5872,6 +5883,10 @@ def test_elementwise(self, input_ids):
output = one_input - inputs_embeds
elif test_op == "Mul":
output = one_input * inputs_embeds
elif test_op == "Div":
output = inputs_embeds / one_input
elif test_op == "Pow":
output = inputs_embeds ** (one_input * 2)
else:
output = None
return output
Expand Down Expand Up @@ -5911,6 +5926,8 @@ def test_other(self, input_ids):
output = torch.nn.functional.gelu(inputs_embeds + bias)
elif test_op == "Gelu":
output = torch.nn.functional.gelu(inputs_embeds)
elif test_op == "Sqrt":
output = torch.sqrt(inputs_embeds)
elif test_op == "ReduceMean":
# In case 0, the inputs_embeds are reduced at last dimension, the ReduceMean should be included in padding
# elimination subgraph and the PadAndUnflatten should be added to output of ReduceMean.
Expand All @@ -5924,7 +5941,7 @@ def test_other(self, input_ids):
return output

def forward(self, input_ids):
if test_op in ["Add", "Mul", "Sub"]:
if test_op in ["Add", "Mul", "Sub", "Div", "Pow"]:
output = self.test_elementwise(input_ids)
elif test_op == "MatMul":
output = self.test_matmul(input_ids)
Expand Down Expand Up @@ -5953,7 +5970,7 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):
model(x)

training_model = model._torch_module._execution_manager(True)._onnx_models.optimized_model
if test_op == "Sub":
if test_op == "Sub" or test_op == "Pow":
assert len([node.op_type for node in training_model.graph.node if node.op_type == "Sub"]) == 2
else:
assert len([node.op_type for node in training_model.graph.node if node.op_type == "Sub"]) == 1
Expand All @@ -5974,19 +5991,16 @@ def find_input_node_type(model, arg):
return result[0].op_type if len(result) == 1 else None

recover_pad_input_optypes = [find_input_node_type(training_model, arg) for arg in recover_pad_node.input]
if test_op == "Add" or test_op == "Mul" or test_op == "Sub":
if test_op == "Add" or test_op == "Mul" or test_op == "Sub" or test_op == "Div" or test_op == "Pow":
assert test_op in recover_pad_input_optypes
else:
if case == 0:
assert test_op in recover_pad_input_optypes
else:
assert "ATen" in recover_pad_input_optypes

del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"]


def test_e2e_padding_elimination():
os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] = "1"
seed = 5033
random.seed(seed)
np.random.seed(seed)
Expand Down Expand Up @@ -6129,7 +6143,6 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):
training_model = ort_model._torch_module._execution_manager(True)._onnx_models.optimized_model
assert "FlattenAndUnpad" in [node.op_type for node in training_model.graph.node]
assert "PadAndUnflatten" in [node.op_type for node in training_model.graph.node]
del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"]


@pytest.mark.skipif(
Expand Down

0 comments on commit 1341182

Please sign in to comment.