Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support Pow/Div/Sqrt in PaddingElimination #20083

Merged
merged 1 commit into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,10 @@ void IterateSubgraphFromNode(Graph& graph,
visited.insert(cur);
if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Add", {7, 13, 14}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "BiasGelu", {1}, kMSDomain) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Sub", {7, 13, 14}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Mul", {7, 13, 14})) {
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Div", {7, 13, 14}) ||
guyang3532 marked this conversation as resolved.
Show resolved Hide resolved
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Mul", {7, 13, 14}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Pow", {7, 12, 13, 15}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Sub", {7, 13, 14})) {
ORT_ENFORCE(subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end() ||
subgraph.find(cur->MutableInputDefs()[1]) != subgraph.end());
if (cur->InputDefs()[0]->Shape() && cur->InputDefs()[1]->Shape()) {
Expand Down Expand Up @@ -278,7 +280,10 @@ void IterateSubgraphFromNode(Graph& graph,
subgraph.insert(cur->MutableOutputDefs()[1]);
PushAllOutputNode(graph, to_visit, cur, visited);
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Cast", {9, 13}) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Gelu", {1}, kMSDomain)) {
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "FastGelu", {1}, kMSDomain) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Gelu", {1}, kMSDomain) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "QuickGelu", {1}, kMSDomain) ||
graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "Sqrt", {6, 13})) {
ORT_ENFORCE(subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end());
subgraph.insert(cur->MutableOutputDefs()[0]);
PushAllOutputNode(graph, to_visit, cur, visited);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5725,8 +5725,6 @@ def run_step(model, input, target):
@pytest.mark.parametrize("label_is_sparse", [False, True])
@pytest.mark.parametrize("rank", [1, 2])
def test_runtime_inspector_label_and_embed_sparsity_detection(embed_is_sparse, label_is_sparse, rank, caplog):
os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] = "1"

class NeuralNetCrossEntropyLoss(torch.nn.Module):
def __init__(self, num_embeddings, embedding_dim):
super().__init__()
Expand Down Expand Up @@ -5797,31 +5795,42 @@ def run_step(model, input, positions):
"test_cases",
[
("Add", 0),
("Add", 1),
("Add", 2),
("Add", 3),
("Add", 4),
("Sub", 0),
("Sub", 1),
("Sub", 2),
("Sub", 3),
("Sub", 4),
("Mul", 0),
("Mul", 2),
("Mul", 3),
("Mul", 4),
("Div", 0),
("Div", 2),
("Div", 3),
("Div", 4),
("Pow", 0),
("Pow", 1),
("Pow", 2),
("Pow", 3),
("Pow", 4),
("MatMul", 0),
("MatMul", 1),
("Dropout", 0),
("LayerNormalization", 0),
("LayerNormalization", 1),
("Cast", 0),
("Sqrt", 0),
("BiasGelu", 0),
("Gelu", 0),
("ReduceMean", 0),
("ReduceMean", 1),
],
)
def test_ops_for_padding_elimination(test_cases):
os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] = "1"
test_op = test_cases[0]
case = test_cases[1]

Expand All @@ -5848,7 +5857,7 @@ def __init__(self, vocab_size, hidden_size, pad_token_id):
# pattern should be inserted for the arg of [batch_size, 1, hidden_size].
# in case 3, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [1, hidden_size],
# the test_op should be included in padding elimination subgraph and an 'Expand + FlattenAndUnpad'
# pattern should be insert to the arg of [batch_size, 1, hidden_size].
# pattern should be inserted for the arg of [1, hidden_size].
# in case 4, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, seqlen, hidden_size],
# the test_op should be included in padding elimination subgraph and the PadAndUnflatten should be added to
# output of test_op. Besides, the other input of Add should be added 'FlattenAndUnpad' to
Expand All @@ -5858,6 +5867,8 @@ def test_elementwise(self, input_ids):
one_input = None
if case == 0:
one_input = torch.ones(self.hidden_size, dtype=torch.long).to(device)
elif case == 1:
one_input = 1
elif case == 2:
one_input = torch.ones((input_shape[0], 1, self.hidden_size), dtype=torch.long).to(device)
elif case == 3:
Expand All @@ -5872,6 +5883,10 @@ def test_elementwise(self, input_ids):
output = one_input - inputs_embeds
elif test_op == "Mul":
output = one_input * inputs_embeds
elif test_op == "Div":
output = inputs_embeds / one_input
elif test_op == "Pow":
output = inputs_embeds ** (one_input * 2)
else:
output = None
return output
Expand Down Expand Up @@ -5911,6 +5926,8 @@ def test_other(self, input_ids):
output = torch.nn.functional.gelu(inputs_embeds + bias)
elif test_op == "Gelu":
output = torch.nn.functional.gelu(inputs_embeds)
elif test_op == "Sqrt":
output = torch.sqrt(inputs_embeds)
elif test_op == "ReduceMean":
# In case 0, the inputs_embeds are reduced at last dimension, the ReduceMean should be included in padding
# elimination subgraph and the PadAndUnflatten should be added to output of ReduceMean.
Expand All @@ -5924,7 +5941,7 @@ def test_other(self, input_ids):
return output

def forward(self, input_ids):
if test_op in ["Add", "Mul", "Sub"]:
if test_op in ["Add", "Mul", "Sub", "Div", "Pow"]:
output = self.test_elementwise(input_ids)
elif test_op == "MatMul":
output = self.test_matmul(input_ids)
Expand Down Expand Up @@ -5953,7 +5970,7 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):
model(x)

training_model = model._torch_module._execution_manager(True)._onnx_models.optimized_model
if test_op == "Sub":
if test_op == "Sub" or test_op == "Pow":
assert len([node.op_type for node in training_model.graph.node if node.op_type == "Sub"]) == 2
else:
assert len([node.op_type for node in training_model.graph.node if node.op_type == "Sub"]) == 1
Expand All @@ -5974,19 +5991,16 @@ def find_input_node_type(model, arg):
return result[0].op_type if len(result) == 1 else None

recover_pad_input_optypes = [find_input_node_type(training_model, arg) for arg in recover_pad_node.input]
if test_op == "Add" or test_op == "Mul" or test_op == "Sub":
if test_op == "Add" or test_op == "Mul" or test_op == "Sub" or test_op == "Div" or test_op == "Pow":
assert test_op in recover_pad_input_optypes
else:
if case == 0:
assert test_op in recover_pad_input_optypes
else:
assert "ATen" in recover_pad_input_optypes

del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"]


def test_e2e_padding_elimination():
os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] = "1"
seed = 5033
random.seed(seed)
np.random.seed(seed)
Expand Down Expand Up @@ -6129,7 +6143,6 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):
training_model = ort_model._torch_module._execution_manager(True)._onnx_models.optimized_model
assert "FlattenAndUnpad" in [node.op_type for node in training_model.graph.node]
assert "PadAndUnflatten" in [node.op_type for node in training_model.graph.node]
del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"]


@pytest.mark.skipif(
Expand Down
Loading