Skip to content

Commit

Permalink
Add FlattenAndUnpad Op (microsoft#17845)
Browse files Browse the repository at this point in the history
### Description
Add an op named `FlattenAndUnpad`.
This op implements functions:
1. Flatten the first two dims of input tensor.
2. Gather valid value from input tensor with index tensor,.


### Motivation and Context
The grad op of `PadAndUnflatten` was `GatherGrad` which is inefficient
in performance.
I implement this `FlattenAndUnpad` just to replace the `GatherGrad` as
grad of `PadAndUnflatten`.
With this op, we also can simplify the "Reshape + ShrunkenGather"
pattern to `PadAndUnflatten` in padding elimination optimizer, which
will also improve performance.
  • Loading branch information
guyang3532 authored and kleiti committed Mar 22, 2024
1 parent b75ebe5 commit 8a509ed
Show file tree
Hide file tree
Showing 17 changed files with 448 additions and 118 deletions.
17 changes: 10 additions & 7 deletions orttraining/orttraining/core/graph/gradient_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -791,13 +791,16 @@ IMPLEMENT_GRADIENT_BUILDER(GetGatherGradient) {

IMPLEMENT_GRADIENT_BUILDER(GetPadAndUnflattenGradient) {
return std::vector<NodeDef>{
NodeDef(OpDef("Reshape"),
{GO(0), O(1)},
{IA("GO_reshaped")}),
NodeDef(OpDef{"Gather", kOnnxDomain, 1},
{IA("GO_reshaped"), I(1)},
{GI(0)},
SrcNodeAttributes())};
NodeDef(OpDef{"FlattenAndUnpad", kMSDomain, 1},
{GO(0), I(1)},
{GI(0), IA("Unflatten_dims")})};
}

IMPLEMENT_GRADIENT_BUILDER(GetFlattenAndUnpadGradient) {
return std::vector<NodeDef>{
NodeDef(OpDef{"PadAndUnflatten", kMSDomain, 1},
{GO(0), I(1), O(1)},
{GI(0)})};
}

IMPLEMENT_GRADIENT_BUILDER(GetShrunkenGatherGradient) {
Expand Down
1 change: 1 addition & 0 deletions orttraining/orttraining/core/graph/gradient_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ DECLARE_GRADIENT_BUILDER(GetAveragePoolGradient)
DECLARE_GRADIENT_BUILDER(GetMaxPoolGradient)
DECLARE_GRADIENT_BUILDER(GetGatherGradient)
DECLARE_GRADIENT_BUILDER(GetPadAndUnflattenGradient)
DECLARE_GRADIENT_BUILDER(GetFlattenAndUnpadGradient)
DECLARE_GRADIENT_BUILDER(GetShrunkenGatherGradient)
DECLARE_GRADIENT_BUILDER(GetConvGradient)
DECLARE_GRADIENT_BUILDER(GetUnsqueezeGradient)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("MaxPool", GetMaxPoolGradient);
REGISTER_GRADIENT_BUILDER("Gather", GetGatherGradient);
REGISTER_GRADIENT_BUILDER("PadAndUnflatten", GetPadAndUnflattenGradient);
REGISTER_GRADIENT_BUILDER("FlattenAndUnpad", GetFlattenAndUnpadGradient);
REGISTER_GRADIENT_BUILDER("ShrunkenGather", GetShrunkenGatherGradient);
REGISTER_GRADIENT_BUILDER("Conv", GetConvGradient);
REGISTER_GRADIENT_BUILDER("Squeeze", GetSqueezeGradient);
Expand Down
26 changes: 22 additions & 4 deletions orttraining/orttraining/core/graph/training_op_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4741,22 +4741,20 @@ Return true if all elements are true and false otherwise.
"For other indices, the corresponding value in output will be padded to zero."

"The indices don't allow duplicated index values, otherwise, though there is no runtime check"
"(in case of performance concern), the behaviour of output is undefined."
"(in case of performance concern), the behavior of output is undefined."

"An example:"
" input: [[1, 2, 3, 4], [5, 6, 7, 8]], shape is [2, 4]"
" indices: [0, 5], shape is [2]"
" unflatten_dims: [2, 3], shape is [2]"

" output: [[[1, 2, 3, 4], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [5, 6, 7, 8]]],"
" shape is [2, 3, 4]"
" flatten_output_shape: [6, 4], shape is [2]")
" shape is [2, 3, 4]")
.Input(0, "input", "input data of rank N, shape is [d1, d2, ..., dN]", "T")
.Input(1, "indices", "1D Tensor of int32/int64 indices, shape is [d1], each element's value ranges in [0, M1*M2).",
"T_INDEX")
.Input(2, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT")
.Output(0, "output", "output data of rank N+1, [M1, M2, d2, ..., dN]", "T")
.Output(1, "flatten_output_shape", "1D tensor with output shape, [M1*M2, d2, ..., dN]", "T_INT")
.TypeConstraint(
"T_INT",
{"tensor(int32)", "tensor(int64)"},
Expand All @@ -4770,6 +4768,26 @@ Return true if all elements are true and false otherwise.
{"tensor(int32)", "tensor(int64)"},
"Constrain indices to integer types");

ONNX_CONTRIB_OPERATOR_SCHEMA(FlattenAndUnpad)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetDoc(
"FlattenAndUnpad operator flattens the first two dims of input tensor, and unpad according to given indices."
"This is used by padding elimination graph transformer.")
.Input(0, "input", "input data of rank N + 1, shape is [M1, M2, d2, ..., dN]", "T")
.Input(1, "indices", "1D Tensor of int32/int64 indices, shape is [d1], each element's value ranges in [0, M1*M2).",
"T_INT")
.Output(0, "output", "output data of rank N, [d1, d2, ..., dN]", "T")
.Output(1, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT")
.TypeConstraint(
"T_INT",
{"tensor(int32)", "tensor(int64)"},
"Constrain indices and shape to integer tensors.")
.TypeConstraint(
"T",
{"tensor(int32)", "tensor(int64)", "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
"Constrain input and output types to float tensors.");

ONNX_CONTRIB_OPERATOR_SCHEMA(GRUTraining)
.SetDomain(kMSDomain)
.SinceVersion(1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,91 +129,43 @@ NodeArg* InsertExpandForNodeInput(Graph& graph,
return new_expand_node->MutableOutputDefs()[0];
}

// Insert Reshape + ShrunkenGather to flatten the in_index-th input of node.
// Insert FlattenAndUnpad to flatten and unpad the in_index-th input of node.
// The gather_index_arg is the indices of the elements that are not padding.
NodeArg* InsertFlattenPatternForInput(Graph& graph,
Node& node,
uint32_t in_index,
NodeArg* gather_index_arg,
const logging::Logger& logger) {
InlinedVector<NodeArg*> reshape_input_args;
reshape_input_args.reserve(2);
reshape_input_args.push_back(node.MutableInputDefs()[in_index]);
std::vector<int64_t> new_shape;
new_shape.push_back(-1); // only support flatten 0 and 1 dims
auto input_shape = node.InputDefs()[in_index]->Shape();
ORT_ENFORCE(input_shape->dim_size() >= 2);
ONNX_NAMESPACE::TensorShapeProto flattened_shape;
if (input_shape->dim(0).has_dim_value() && input_shape->dim(1).has_dim_value()) {
flattened_shape.add_dim()->set_dim_value(input_shape->dim(0).dim_value() * input_shape->dim(1).dim_value());
} else {
std::string token_dim_name = MakeString("total_token_count_", utils::GetRandomSeed());
flattened_shape.add_dim()->set_dim_param(token_dim_name);
}
for (int k = 2; k < input_shape->dim_size(); k++) {
ORT_ENFORCE(input_shape->dim(k).has_dim_value());
new_shape.push_back(input_shape->dim(k).dim_value());
flattened_shape.add_dim()->set_dim_value(input_shape->dim(k).dim_value());
}
ONNX_NAMESPACE::TensorProto new_shape_const_tensor;
new_shape_const_tensor.set_name(graph.GenerateNodeArgName("new_shape"));
new_shape_const_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
new_shape_const_tensor.add_dims(new_shape.size());
new_shape_const_tensor.set_raw_data(new_shape.data(), new_shape.size() * sizeof(int64_t));
NodeArg* new_shape_arg = &graph_utils::AddInitializer(graph, new_shape_const_tensor);
reshape_input_args.push_back(new_shape_arg);

InlinedVector<NodeArg*> reshape_output_args;
reshape_output_args.push_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("inputs_reshape_result"),
node.MutableInputDefs()[in_index]->TypeAsProto()));

Node* new_reshape_node = InsertIntermediateNodeOnDestInput(
graph, node,
in_index,
0,
0,
graph.GenerateNodeName("Reshape"),
"Reshape",
"Reshape node to filter invalid tokens.",
reshape_input_args,
reshape_output_args,
{},
"",
logger);
InlinedVector<NodeArg*> unpad_input_args;
unpad_input_args.reserve(2);
unpad_input_args.push_back(node.MutableInputDefs()[in_index]);
unpad_input_args.push_back(gather_index_arg);

new_reshape_node->SetExecutionProviderType(node.GetExecutionProviderType());
auto reshape_out_arg = new_reshape_node->MutableOutputDefs()[0];

reshape_out_arg->SetShape(flattened_shape);

InlinedVector<NodeArg*> gather_input_args;
gather_input_args.reserve(2);
gather_input_args.push_back(reshape_output_args[0]);
gather_input_args.push_back(gather_index_arg);

InlinedVector<NodeArg*> gather_output_args;
gather_output_args.push_back(
InlinedVector<NodeArg*> unpad_output_args;
unpad_output_args.push_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padding_filter_result"),
reshape_out_arg->TypeAsProto()));
nullptr));
unpad_output_args.push_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("d1_d2_shape"),
nullptr));

Node* new_gather_node = InsertIntermediateNodeOnDestInput(
Node* unpad_node = InsertIntermediateNodeOnDestInput(
graph, node,
in_index,
0,
0,
graph.GenerateNodeName("PaddingFilter"),
"ShrunkenGather",
"ShrunkenGather node to filter invalid tokens.",
gather_input_args,
gather_output_args,
"FlattenAndUnpad",
"FlattenAndUnpad node to filter invalid tokens.",
unpad_input_args,
unpad_output_args,
{},
kMSDomain,
logger);

new_gather_node->SetExecutionProviderType(node.GetExecutionProviderType());
auto gather_out_arg = new_gather_node->MutableOutputDefs()[0];
return gather_out_arg;
unpad_node->SetExecutionProviderType(node.GetExecutionProviderType());
auto unpad_out_arg = unpad_node->MutableOutputDefs()[0];
return unpad_out_arg;
}

// Insert PadAndUnflatten to unflatten the shape of the in_index-th input of node.
Expand All @@ -236,10 +188,6 @@ NodeArg* InsertNodesForOutput(Graph& graph,
pad_node_output_args.push_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_result"),
nullptr));
pad_node_output_args.push_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_d1xd2_shape"),
nullptr));

Node* new_gathergrad_node = InsertIntermediateNodeOnDestInput(
graph, node,
in_index,
Expand Down
3 changes: 1 addition & 2 deletions orttraining/orttraining/test/gradient/gradient_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3011,7 +3011,6 @@ TEST(GradientCheckerTest, PadAndUnflattenGrad) {
std::vector<std::vector<float>> x_datas = {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 5, 0, 1}, {5, 2}};

TensorInfo padded_out_info({5, 2, 3}, true);
TensorInfo out_shape_info({2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
#ifdef USE_CUDA
Expand All @@ -3021,7 +3020,7 @@ TEST(GradientCheckerTest, PadAndUnflattenGrad) {
#endif

ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info, indices_info, shape_info},
{padded_out_info, out_shape_info}, &max_error,
{padded_out_info}, &max_error,
x_datas, {}, true, false, &execution_providers));
EXPECT_IS_TINY(max_error);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5786,14 +5786,14 @@ def __init__(self, vocab_size, hidden_size, pad_token_id):
# the test_op should be included in padding elimination subgraph and the PadAndUnflatten should be
# added to output of test_op.
# in case 2, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, 1, hidden_size],
# the test_op should be included in padding elimination subgraph and a 'Expand + Reshape + ShrunkenGather'
# the test_op should be included in padding elimination subgraph and a 'Expand + FlattenAndUnpad'
# pattern should be insert to the arg of [batch_size, 1, hidden_size].
# in case 3, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [1, hidden_size],
# the test_op should be included in padding elimination subgraph and a 'Expand + Reshape + ShrunkenGather'
# the test_op should be included in padding elimination subgraph and a 'Expand + FlattenAndUnpad'
# pattern should be insert to the arg of [batch_size, 1, hidden_size].
# in case 4, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, seqlen, hidden_size],
# the test_op should be included in padding elimination subgraph and the PadAndUnflatten should be added to
# output of test_op. Besides, the other input of Add should be added 'Reshape + ShrunkenGather' to
# output of test_op. Besides, the other input of Add should be added 'FlattenAndUnpad' to
# flatten and elimination padding.
def test_elementwise(self, input_ids):
input_shape = input_ids.size()
Expand Down Expand Up @@ -5905,9 +5905,9 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):
assert len([node.op_type for node in training_model.graph.node if node.op_type == "Squeeze"]) == 1
assert len([node.op_type for node in training_model.graph.node if node.op_type == "PadAndUnflatten"]) == 1
if case >= 2:
assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 2
assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 3
else:
assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 1
assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 2
gathergrad_node = next(node for node in training_model.graph.node if node.op_type == "PadAndUnflatten")

def find_input_node_type(model, arg):
Expand Down Expand Up @@ -6071,7 +6071,7 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-3, rtol=1e-4)

training_model = ort_model._torch_module._execution_manager(True)._onnx_models.optimized_model
assert "ShrunkenGather" in [node.op_type for node in training_model.graph.node]
assert "FlattenAndUnpad" in [node.op_type for node in training_model.graph.node]
assert "PadAndUnflatten" in [node.op_type for node in training_model.graph.node]
del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"]

Expand Down
Loading

0 comments on commit 8a509ed

Please sign in to comment.