Add FlattenAndUnpad Op #17845

Merged · 3 commits · Nov 9, 2023
17 changes: 10 additions & 7 deletions orttraining/orttraining/core/graph/gradient_builder.cc
@@ -791,13 +791,16 @@ IMPLEMENT_GRADIENT_BUILDER(GetGatherGradient) {
 
 IMPLEMENT_GRADIENT_BUILDER(GetPadAndUnflattenGradient) {
   return std::vector<NodeDef>{
-      NodeDef(OpDef("Reshape"),
-              {GO(0), O(1)},
-              {IA("GO_reshaped")}),
-      NodeDef(OpDef{"Gather", kOnnxDomain, 1},
-              {IA("GO_reshaped"), I(1)},
-              {GI(0)},
-              SrcNodeAttributes())};
+      NodeDef(OpDef{"FlattenAndUnpad", kMSDomain, 1},
+              {GO(0), I(1)},
+              {GI(0), IA("Unflatten_dims")})};
 }
+
+IMPLEMENT_GRADIENT_BUILDER(GetFlattenAndUnpadGradient) {
+  return std::vector<NodeDef>{
+      NodeDef(OpDef{"PadAndUnflatten", kMSDomain, 1},
+              {GO(0), I(1), O(1)},
+              {GI(0)})};
+}
 
 IMPLEMENT_GRADIENT_BUILDER(GetShrunkenGatherGradient) {
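The two builders above wire the ops up as mutual inverses: the gradient of PadAndUnflatten is a FlattenAndUnpad over the upstream gradient (reusing the forward indices), and the gradient of FlattenAndUnpad is a PadAndUnflatten that scatters the upstream gradient back using the recorded `unflatten_dims` (O(1)). A minimal NumPy sketch of the forward semantics — function names and shapes here are illustrative, not kernel code from this PR:

```python
import numpy as np

def flatten_and_unpad(x, indices):
    """Flatten the first two dims of x, then keep only the rows at `indices`.
    Mirrors the contrib op: also returns the original [M1, M2] dims."""
    m1, m2 = x.shape[0], x.shape[1]
    flat = x.reshape(m1 * m2, *x.shape[2:])
    return flat[indices], np.array([m1, m2], dtype=np.int64)

def pad_and_unflatten(x, indices, unflatten_dims):
    """Scatter the rows of x into a zero tensor and restore shape [M1, M2, ...]."""
    m1, m2 = int(unflatten_dims[0]), int(unflatten_dims[1])
    out = np.zeros((m1 * m2, *x.shape[1:]), dtype=x.dtype)
    out[indices] = x
    return out.reshape(m1, m2, *x.shape[1:])

# Round trip: unpadding and re-padding preserves the non-padding rows.
x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
idx = np.array([0, 5])
tokens, dims = flatten_and_unpad(x, idx)          # tokens: [2, 4], dims: [2, 3]
restored = pad_and_unflatten(tokens, idx, dims)   # [2, 3, 4], zeros off `idx`
assert (restored.reshape(-1, 4)[idx] == tokens).all()
```

Emitting `unflatten_dims` as a second output is what lets the backward PadAndUnflatten recover [M1, M2] without re-deriving the forward input's shape.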
1 change: 1 addition & 0 deletions orttraining/orttraining/core/graph/gradient_builder.h
@@ -40,6 +40,7 @@ DECLARE_GRADIENT_BUILDER(GetAveragePoolGradient)
 DECLARE_GRADIENT_BUILDER(GetMaxPoolGradient)
 DECLARE_GRADIENT_BUILDER(GetGatherGradient)
 DECLARE_GRADIENT_BUILDER(GetPadAndUnflattenGradient)
+DECLARE_GRADIENT_BUILDER(GetFlattenAndUnpadGradient)
 DECLARE_GRADIENT_BUILDER(GetShrunkenGatherGradient)
 DECLARE_GRADIENT_BUILDER(GetConvGradient)
 DECLARE_GRADIENT_BUILDER(GetUnsqueezeGradient)
@@ -72,6 +72,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
   REGISTER_GRADIENT_BUILDER("MaxPool", GetMaxPoolGradient);
   REGISTER_GRADIENT_BUILDER("Gather", GetGatherGradient);
   REGISTER_GRADIENT_BUILDER("PadAndUnflatten", GetPadAndUnflattenGradient);
+  REGISTER_GRADIENT_BUILDER("FlattenAndUnpad", GetFlattenAndUnpadGradient);
   REGISTER_GRADIENT_BUILDER("ShrunkenGather", GetShrunkenGatherGradient);
   REGISTER_GRADIENT_BUILDER("Conv", GetConvGradient);
   REGISTER_GRADIENT_BUILDER("Squeeze", GetSqueezeGradient);
26 changes: 22 additions & 4 deletions orttraining/orttraining/core/graph/training_op_defs.cc
@@ -4741,22 +4741,20 @@
         "For other indices, the corresponding value in output will be padded to zero."
 
         "The indices don't allow duplicated index values, otherwise, though there is no runtime check"
-        "(in case of performance concern), the behaviour of output is undefined."
+        "(in case of performance concern), the behavior of output is undefined."
 
         "An example:"
         "  input: [[1, 2, 3, 4], [5, 6, 7, 8]], shape is [2, 4]"
         "  indices: [0, 5], shape is [2]"
         "  unflatten_dims: [2, 3], shape is [2]"
 
         "  output: [[[1, 2, 3, 4], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [5, 6, 7, 8]]],"
-        "  shape is [2, 3, 4]"
-        "  flatten_output_shape: [6, 4], shape is [2]")
+        "  shape is [2, 3, 4]")
     .Input(0, "input", "input data of rank N, shape is [d1, d2, ..., dN]", "T")
     .Input(1, "indices", "1D Tensor of int32/int64 indices, shape is [d1], each element's value ranges in [0, M1*M2).",
            "T_INDEX")
     .Input(2, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT")
     .Output(0, "output", "output data of rank N+1, [M1, M2, d2, ..., dN]", "T")
-    .Output(1, "flatten_output_shape", "1D tensor with output shape, [M1*M2, d2, ..., dN]", "T_INT")
     .TypeConstraint(
         "T_INT",
         {"tensor(int32)", "tensor(int64)"},
@@ -4770,6 +4768,26 @@
         {"tensor(int32)", "tensor(int64)"},
         "Constrain indices to integer types");
 
+  ONNX_CONTRIB_OPERATOR_SCHEMA(FlattenAndUnpad)
+      .SetDomain(kMSDomain)
+      .SinceVersion(1)
+      .SetDoc(
+          "FlattenAndUnpad operator flattens the first two dims of the input tensor and unpads it according to the given indices. "
+          "This is used by the padding elimination graph transformer.")
+      .Input(0, "input", "input data of rank N + 1, shape is [M1, M2, d2, ..., dN]", "T")
+      .Input(1, "indices", "1D Tensor of int32/int64 indices, shape is [d1], each element's value ranges in [0, M1*M2).",
+             "T_INT")
+      .Output(0, "output", "output data of rank N, [d1, d2, ..., dN]", "T")
+      .Output(1, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT")
+      .TypeConstraint(
+          "T_INT",
+          {"tensor(int32)", "tensor(int64)"},
+          "Constrain indices and shape to integer tensors.")
+      .TypeConstraint(
+          "T",
+          {"tensor(int32)", "tensor(int64)", "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
+          "Constrain input and output types to float tensors.");
+
 ONNX_CONTRIB_OPERATOR_SCHEMA(GRUTraining)
     .SetDomain(kMSDomain)
     .SinceVersion(1)

GitHub Actions / cpplint flagged line 4778 (the `.Input(1, "indices", ...)` line above): Lines should be <= 120 characters long [whitespace/line_length] [2]
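To make the new schema concrete, here is the documented PadAndUnflatten example run in reverse — FlattenAndUnpad is its exact inverse. A NumPy sketch, illustrative only, not the kernel:

```python
import numpy as np

# Input reuses the output of the PadAndUnflatten doc example above:
# rank N + 1, shape [M1=2, M2=3, d2=4].
inp = np.array([[[1, 2, 3, 4], [0, 0, 0, 0], [0, 0, 0, 0]],
                [[0, 0, 0, 0], [0, 0, 0, 0], [5, 6, 7, 8]]])
indices = np.array([0, 5])  # positions of the non-padding rows, in [0, M1*M2)

output = inp.reshape(-1, 4)[indices]       # [[1, 2, 3, 4], [5, 6, 7, 8]], shape [2, 4]
unflatten_dims = np.array(inp.shape[:2])   # [2, 3], consumed by the backward PadAndUnflatten
```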
@@ -129,91 +129,43 @@ NodeArg* InsertExpandForNodeInput(Graph& graph,
   return new_expand_node->MutableOutputDefs()[0];
 }
 
-// Insert Reshape + ShrunkenGather to flatten the in_index-th input of node.
+// Insert FlattenAndUnpad to flatten and unpad the in_index-th input of node.
 // The gather_index_arg is the indices of the elements that are not padding.
 NodeArg* InsertFlattenPatternForInput(Graph& graph,
                                       Node& node,
                                       uint32_t in_index,
                                       NodeArg* gather_index_arg,
                                       const logging::Logger& logger) {
-  InlinedVector<NodeArg*> reshape_input_args;
-  reshape_input_args.reserve(2);
-  reshape_input_args.push_back(node.MutableInputDefs()[in_index]);
-  std::vector<int64_t> new_shape;
-  new_shape.push_back(-1);  // only support flatten 0 and 1 dims
-  auto input_shape = node.InputDefs()[in_index]->Shape();
-  ORT_ENFORCE(input_shape->dim_size() >= 2);
-  ONNX_NAMESPACE::TensorShapeProto flattened_shape;
-  if (input_shape->dim(0).has_dim_value() && input_shape->dim(1).has_dim_value()) {
-    flattened_shape.add_dim()->set_dim_value(input_shape->dim(0).dim_value() * input_shape->dim(1).dim_value());
-  } else {
-    std::string token_dim_name = MakeString("total_token_count_", utils::GetRandomSeed());
-    flattened_shape.add_dim()->set_dim_param(token_dim_name);
-  }
-  for (int k = 2; k < input_shape->dim_size(); k++) {
-    ORT_ENFORCE(input_shape->dim(k).has_dim_value());
-    new_shape.push_back(input_shape->dim(k).dim_value());
-    flattened_shape.add_dim()->set_dim_value(input_shape->dim(k).dim_value());
-  }
-  ONNX_NAMESPACE::TensorProto new_shape_const_tensor;
-  new_shape_const_tensor.set_name(graph.GenerateNodeArgName("new_shape"));
-  new_shape_const_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
-  new_shape_const_tensor.add_dims(new_shape.size());
-  new_shape_const_tensor.set_raw_data(new_shape.data(), new_shape.size() * sizeof(int64_t));
-  NodeArg* new_shape_arg = &graph_utils::AddInitializer(graph, new_shape_const_tensor);
-  reshape_input_args.push_back(new_shape_arg);
-
-  InlinedVector<NodeArg*> reshape_output_args;
-  reshape_output_args.push_back(
-      &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("inputs_reshape_result"),
-                                node.MutableInputDefs()[in_index]->TypeAsProto()));
-
-  Node* new_reshape_node = InsertIntermediateNodeOnDestInput(
-      graph, node,
-      in_index,
-      0,
-      0,
-      graph.GenerateNodeName("Reshape"),
-      "Reshape",
-      "Reshape node to filter invalid tokens.",
-      reshape_input_args,
-      reshape_output_args,
-      {},
-      "",
-      logger);
+  InlinedVector<NodeArg*> unpad_input_args;
+  unpad_input_args.reserve(2);
+  unpad_input_args.push_back(node.MutableInputDefs()[in_index]);
+  unpad_input_args.push_back(gather_index_arg);
 
-  new_reshape_node->SetExecutionProviderType(node.GetExecutionProviderType());
-  auto reshape_out_arg = new_reshape_node->MutableOutputDefs()[0];
-
-  reshape_out_arg->SetShape(flattened_shape);
-
-  InlinedVector<NodeArg*> gather_input_args;
-  gather_input_args.reserve(2);
-  gather_input_args.push_back(reshape_output_args[0]);
-  gather_input_args.push_back(gather_index_arg);
-
-  InlinedVector<NodeArg*> gather_output_args;
-  gather_output_args.push_back(
+  InlinedVector<NodeArg*> unpad_output_args;
+  unpad_output_args.push_back(
       &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padding_filter_result"),
-                                reshape_out_arg->TypeAsProto()));
+                                nullptr));
+  unpad_output_args.push_back(
+      &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("d1_d2_shape"),
+                                nullptr));
 
-  Node* new_gather_node = InsertIntermediateNodeOnDestInput(
+  Node* unpad_node = InsertIntermediateNodeOnDestInput(
       graph, node,
       in_index,
      0,
      0,
      graph.GenerateNodeName("PaddingFilter"),
-      "ShrunkenGather",
-      "ShrunkenGather node to filter invalid tokens.",
-      gather_input_args,
-      gather_output_args,
+      "FlattenAndUnpad",
+      "FlattenAndUnpad node to filter invalid tokens.",
+      unpad_input_args,
+      unpad_output_args,
       {},
       kMSDomain,
       logger);
 
-  new_gather_node->SetExecutionProviderType(node.GetExecutionProviderType());
-  auto gather_out_arg = new_gather_node->MutableOutputDefs()[0];
-  return gather_out_arg;
+  unpad_node->SetExecutionProviderType(node.GetExecutionProviderType());
+  auto unpad_out_arg = unpad_node->MutableOutputDefs()[0];
+  return unpad_out_arg;
 }
 
 // Insert PadAndUnflatten to unflatten the shape of the in_index-th input of node.
@@ -236,10 +188,6 @@ NodeArg* InsertNodesForOutput(Graph& graph,
   pad_node_output_args.push_back(
       &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_result"),
                                 nullptr));
-  pad_node_output_args.push_back(
-      &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_d1xd2_shape"),
-                                nullptr));
-
   Node* new_gathergrad_node = InsertIntermediateNodeOnDestInput(
       graph, node,
       in_index,
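The transformer change above collapses the old two-node pattern — a Reshape driven by a synthesized shape initializer, followed by a ShrunkenGather — into a single FlattenAndUnpad node whose second output carries the original [d1, d2] dims for the backward pass. A sketch of the inserted node using the ONNX Python helper (tensor names are made up for illustration; the real code goes through InsertIntermediateNodeOnDestInput):

```python
from onnx import helper

# Hypothetical equivalent of what InsertFlattenPatternForInput now emits.
unpad_node = helper.make_node(
    "FlattenAndUnpad",
    inputs=["embedding_output", "valid_token_indices"],  # [batch, seqlen, ...] + non-padding positions
    outputs=["padding_filter_result", "d1_d2_shape"],    # flattened tokens + original [d1, d2]
    name="PaddingFilter",
    domain="com.microsoft",                              # kMSDomain
)
print(unpad_node)
```

Besides saving a node, this drops the shape initializer the old path had to build for Reshape and the manually constructed symbolic `total_token_count_*` dimension on the intermediate output.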
3 changes: 1 addition & 2 deletions orttraining/orttraining/test/gradient/gradient_ops_test.cc
@@ -3011,7 +3011,6 @@ TEST(GradientCheckerTest, PadAndUnflattenGrad) {
   std::vector<std::vector<float>> x_datas = {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 5, 0, 1}, {5, 2}};
 
   TensorInfo padded_out_info({5, 2, 3}, true);
-  TensorInfo out_shape_info({2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
 
   std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
 #ifdef USE_CUDA
@@ -3021,7 +3020,7 @@
 #endif
 
   ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info, indices_info, shape_info},
-                                                         {padded_out_info, out_shape_info}, &max_error,
+                                                         {padded_out_info}, &max_error,
                                                          x_datas, {}, true, false, &execution_providers));
   EXPECT_IS_TINY(max_error);
 }
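The test now expects a single output, since PadAndUnflatten no longer produces `flatten_output_shape`. For intuition, here is a NumPy sketch of the check that ORT's GradientChecker automates, using the same test data; the helper below is an illustrative stand-in, not ORT code. Per GetPadAndUnflattenGradient, the analytic input gradient is a flatten-and-gather of the upstream gradient, which a central finite difference confirms:

```python
import numpy as np

def pad_and_unflatten(x, indices, dims):
    """Illustrative stand-in for the op under test (not ORT code)."""
    out = np.zeros((dims[0] * dims[1],) + x.shape[1:], dtype=x.dtype)
    out[indices] = x
    return out.reshape(dims + x.shape[1:])

# Same data as the test: x is [4, 3], indices {3, 5, 0, 1}, unflatten_dims {5, 2}.
x = np.arange(12, dtype=np.float64).reshape(4, 3)
idx = np.array([3, 5, 0, 1])
dims = (5, 2)

go = np.random.default_rng(0).standard_normal((5, 2, 3))  # upstream gradient

# Analytic gradient: flatten go's first two dims, gather rows at idx.
analytic = go.reshape(-1, 3)[idx]

# Central finite difference for one input element.
eps = 1e-6
e = np.zeros_like(x)
e[2, 1] = eps
num = ((pad_and_unflatten(x + e, idx, dims) - pad_and_unflatten(x - e, idx, dims)) * go).sum() / (2 * eps)
assert np.isclose(num, analytic[2, 1])
```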
@@ -5786,14 +5786,14 @@ def __init__(self, vocab_size, hidden_size, pad_token_id):
     # the test_op should be included in padding elimination subgraph and the PadAndUnflatten should be
     # added to output of test_op.
     # in case 2, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, 1, hidden_size],
-    # the test_op should be included in padding elimination subgraph and an 'Expand + Reshape + ShrunkenGather'
+    # the test_op should be included in padding elimination subgraph and an 'Expand + FlattenAndUnpad'
     # pattern should be inserted for the arg of [batch_size, 1, hidden_size].
     # in case 3, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [1, hidden_size],
-    # the test_op should be included in padding elimination subgraph and an 'Expand + Reshape + ShrunkenGather'
+    # the test_op should be included in padding elimination subgraph and an 'Expand + FlattenAndUnpad'
     # pattern should be inserted for the arg of [1, hidden_size].
     # in case 4, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, seqlen, hidden_size],
     # the test_op should be included in padding elimination subgraph and the PadAndUnflatten should be added to
-    # the output of test_op. Besides, a 'Reshape + ShrunkenGather' pattern should be added to the other input of Add to
+    # the output of test_op. Besides, a 'FlattenAndUnpad' should be added to the other input of Add to
     # flatten and eliminate the padding.
     def test_elementwise(self, input_ids):
         input_shape = input_ids.size()
@@ -5905,9 +5905,9 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):
     assert len([node.op_type for node in training_model.graph.node if node.op_type == "Squeeze"]) == 1
     assert len([node.op_type for node in training_model.graph.node if node.op_type == "PadAndUnflatten"]) == 1
     if case >= 2:
-        assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 2
+        assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 3
     else:
-        assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 1
+        assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 2
     gathergrad_node = next(node for node in training_model.graph.node if node.op_type == "PadAndUnflatten")
 
     def find_input_node_type(model, arg):
@@ -6071,7 +6071,7 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):
     _test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-3, rtol=1e-4)
 
     training_model = ort_model._torch_module._execution_manager(True)._onnx_models.optimized_model
-    assert "ShrunkenGather" in [node.op_type for node in training_model.graph.node]
+    assert "FlattenAndUnpad" in [node.op_type for node in training_model.graph.node]
     assert "PadAndUnflatten" in [node.op_type for node in training_model.graph.node]
     del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"]