Skip to content

Commit

Permalink
Add FlattenAndUnpad Op
Browse files Browse the repository at this point in the history
  • Loading branch information
guyang3532 committed Oct 10, 2023
1 parent 2ef6ee6 commit cbf9348
Show file tree
Hide file tree
Showing 16 changed files with 354 additions and 107 deletions.
17 changes: 10 additions & 7 deletions orttraining/orttraining/core/graph/gradient_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -755,13 +755,16 @@ IMPLEMENT_GRADIENT_BUILDER(GetGatherGradient) {

// Gradient of PadAndUnflatten: a single FlattenAndUnpad node (com.microsoft domain, version 1)
// applied to the incoming output gradient.
//
// Node inputs:  GO(0) and I(1). Per the PadAndUnflatten schema, input 1 is `indices`
//               (assuming I(k) denotes the k-th forward input).
// Node outputs: GI(0), plus the op's second output (`unflatten_dims` in the FlattenAndUnpad schema),
//               which is bound to an intermediate arg named "No_use" — presumably because nothing
//               in the backward graph consumes it.
IMPLEMENT_GRADIENT_BUILDER(GetPadAndUnflattenGradient) {
  return std::vector<NodeDef>{
      NodeDef(OpDef{"FlattenAndUnpad", kMSDomain, 1},
              {GO(0), I(1)},
              {GI(0), IA("No_use")})};
}

// Gradient of FlattenAndUnpad: a single PadAndUnflatten node (com.microsoft domain, version 1).
//
// Node inputs:  GO(0), I(1) and O(1). Per the FlattenAndUnpad schema, input 1 is `indices` and
//               output 1 is `unflatten_dims` (assuming I(k)/O(k) denote the k-th forward input/output).
// Node outputs: GI(0).
IMPLEMENT_GRADIENT_BUILDER(GetFlattenAndUnpadGradient) {
  std::vector<NodeDef> grad_nodes;
  grad_nodes.push_back(NodeDef(OpDef{"PadAndUnflatten", kMSDomain, 1},
                               {GO(0), I(1), O(1)},
                               {GI(0)}));
  return grad_nodes;
}

IMPLEMENT_GRADIENT_BUILDER(GetShrunkenGatherGradient) {
Expand Down
1 change: 1 addition & 0 deletions orttraining/orttraining/core/graph/gradient_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ DECLARE_GRADIENT_BUILDER(GetAveragePoolGradient)
DECLARE_GRADIENT_BUILDER(GetMaxPoolGradient)
DECLARE_GRADIENT_BUILDER(GetGatherGradient)
DECLARE_GRADIENT_BUILDER(GetPadAndUnflattenGradient)
DECLARE_GRADIENT_BUILDER(GetFlattenAndUnpadGradient)
DECLARE_GRADIENT_BUILDER(GetShrunkenGatherGradient)
DECLARE_GRADIENT_BUILDER(GetConvGradient)
DECLARE_GRADIENT_BUILDER(GetUnsqueezeGradient)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("MaxPool", GetMaxPoolGradient);
REGISTER_GRADIENT_BUILDER("Gather", GetGatherGradient);
REGISTER_GRADIENT_BUILDER("PadAndUnflatten", GetPadAndUnflattenGradient);
REGISTER_GRADIENT_BUILDER("FlattenAndUnpad", GetFlattenAndUnpadGradient);
REGISTER_GRADIENT_BUILDER("ShrunkenGather", GetShrunkenGatherGradient);
REGISTER_GRADIENT_BUILDER("Conv", GetConvGradient);
REGISTER_GRADIENT_BUILDER("Squeeze", GetSqueezeGradient);
Expand Down
24 changes: 21 additions & 3 deletions orttraining/orttraining/core/graph/training_op_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4737,14 +4737,12 @@ Return true if all elements are true and false otherwise.
" unflatten_dims: [2, 3], shape is [2]"

" output: [[[1, 2, 3, 4], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [5, 6, 7, 8]]],"
" shape is [2, 3, 4]"
" flatten_output_shape: [6, 4], shape is [2]")
" shape is [2, 3, 4]")
.Input(0, "input", "input data of rank N, shape is [d1, d2, ..., dN]", "T")
.Input(1, "indices", "1D Tensor of int32/int64 indices, shape is [d1], each element's value ranges in [0, M1*M2).",
"T_INDEX")
.Input(2, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT")
.Output(0, "output", "output data of rank N+1, [M1, M2, d2, ..., dN]", "T")
.Output(1, "flatten_output_shape", "1D tensor with output shape, [M1*M2, d2, ..., dN]", "T_INT")
.TypeConstraint(
"T_INT",
{"tensor(int32)", "tensor(int64)"},
Expand All @@ -4758,6 +4756,26 @@ Return true if all elements are true and false otherwise.
{"tensor(int32)", "tensor(int64)"},
"Constrain indices to integer types");

// Schema for FlattenAndUnpad (com.microsoft domain, since version 1).
// Inputs:  `input` (T) and `indices` (T_INT).
// Outputs: `output` (T) and `unflatten_dims` (T_INT, the two leading dims [M1, M2] of `input`).
// Rank wording follows the PadAndUnflatten schema: [d1, ..., dN] is rank N and
// [M1, M2, d2, ..., dN] is rank N+1.
ONNX_CONTRIB_OPERATOR_SCHEMA(FlattenAndUnpad)
    .SetDomain(kMSDomain)
    .SinceVersion(1)
    .SetDoc(
        // Adjacent literals are concatenated, so the trailing space keeps the sentences apart.
        "FlattenAndUnpad operator flattens the first two dims of input tensor, "
        "and unpad according to given indices. "
        "This is used by padding elimination graph transformers.")
    .Input(0, "input", "input data of rank N+1, shape is [M1, M2, d2, ..., dN]", "T")
    .Input(1, "indices",
           "1D Tensor of int32/int64 indices, shape is [d1], each element's value ranges in [0, M1*M2).",
           "T_INT")
    .Output(0, "output", "output data of rank N, [d1, d2, ..., dN]", "T")
    .Output(1, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT")
    .TypeConstraint(
        "T_INT",
        {"tensor(int32)", "tensor(int64)"},
        "Constrain indices and shape to integer tensors.")
    .TypeConstraint(
        "T",
        {"tensor(int32)", "tensor(int64)", "tensor(float16)",
         "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
        "Constrain input and output types to numeric tensors.");

ONNX_CONTRIB_OPERATOR_SCHEMA(GRUTraining)
.SetDomain(kMSDomain)
.SinceVersion(1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,91 +129,43 @@ NodeArg* InsertExpandForNodeInput(Graph& graph,
return new_expand_node->MutableOutputDefs()[0];
}

// Insert FlattenAndUnpad to flatten the in_index-th input of node.
// The gather_index_arg is the indices of the elements that are not padding.
//
// The inserted node (com.microsoft domain) takes the original in_index-th input and gather_index_arg,
// and produces two freshly named outputs created without type info (nullptr):
// "padding_filter_result" and "d1_d2_shape". It is assigned the same execution provider as `node`.
//
// Returns the first output def of the inserted node.
NodeArg* InsertFlattenPatternForInput(Graph& graph,
                                      Node& node,
                                      uint32_t in_index,
                                      NodeArg* gather_index_arg,
                                      const logging::Logger& logger) {
  InlinedVector<NodeArg*> unpad_input_args;
  unpad_input_args.reserve(2);
  unpad_input_args.push_back(node.MutableInputDefs()[in_index]);
  unpad_input_args.push_back(gather_index_arg);

  InlinedVector<NodeArg*> unpad_output_args;
  unpad_output_args.push_back(
      &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padding_filter_result"),
                                nullptr));
  unpad_output_args.push_back(
      &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("d1_d2_shape"),
                                nullptr));

  Node* unpad_node = InsertIntermediateNodeOnDestInput(
      graph, node,
      in_index,
      0,
      0,
      graph.GenerateNodeName("PaddingFilter"),
      "FlattenAndUnpad",
      "FlattenAndUnpad node to filter invalid tokens.",
      unpad_input_args,
      unpad_output_args,
      {},
      kMSDomain,
      logger);

  unpad_node->SetExecutionProviderType(node.GetExecutionProviderType());
  auto unpad_out_arg = unpad_node->MutableOutputDefs()[0];
  return unpad_out_arg;
}

// Insert PadAndUnflatten to unflatten the shape of the in_index-th input of node.
Expand All @@ -236,10 +188,6 @@ NodeArg* InsertNodesForOutput(Graph& graph,
pad_node_output_args.push_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_result"),
nullptr));
pad_node_output_args.push_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_d1xd2_shape"),
nullptr));

Node* new_gathergrad_node = InsertIntermediateNodeOnDestInput(
graph, node,
in_index,
Expand Down
3 changes: 1 addition & 2 deletions orttraining/orttraining/test/gradient/gradient_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3011,7 +3011,6 @@ TEST(GradientCheckerTest, PadAndUnflattenGrad) {
std::vector<std::vector<float>> x_datas = {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 5, 0, 1}, {5, 2}};

TensorInfo padded_out_info({5, 2, 3}, true);
TensorInfo out_shape_info({2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
#ifdef USE_CUDA
Expand All @@ -3021,7 +3020,7 @@ TEST(GradientCheckerTest, PadAndUnflattenGrad) {
#endif

ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info, indices_info, shape_info},
{padded_out_info, out_shape_info}, &max_error,
{padded_out_info}, &max_error,
x_datas, {}, true, false, &execution_providers));
EXPECT_IS_TINY(max_error);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5901,9 +5901,9 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):
assert len([node.op_type for node in training_model.graph.node if node.op_type == "Squeeze"]) == 1
assert len([node.op_type for node in training_model.graph.node if node.op_type == "PadAndUnflatten"]) == 1
if case >= 2:
assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 2
assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 3
else:
assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 1
assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 2
gathergrad_node = next(node for node in training_model.graph.node if node.op_type == "PadAndUnflatten")

def find_input_node_type(model, arg):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "test/common/tensor_op_test_utils.h"
#include "test/providers/provider_test_utils.h"

namespace onnxruntime {
namespace test {

#if defined(USE_CUDA) || defined(USE_ROCM)

// float input of shape [5, 3]: FlattenAndUnpad is expected to return the six elements at the
// given flat indices, together with the original leading dims {5, 3}.
TEST(FlattenAndUnpadTest, FloatType1D) {
  const std::vector<float> padded_input = {1.0f, 1.0f, 3.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
                                           0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};
  const std::vector<int64_t> kept_indices = {1, 3, 5, 7, 9, 11};

  const std::vector<float> expected_output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  const std::vector<int64_t> expected_unflatten_dims = {5, 3};

  OpTester tester("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
  tester.AddInput<float>("input", {5, 3}, padded_input);
  tester.AddInput<int64_t>("indices", {6}, kept_indices);
  tester.AddOutput<float>("output", {6}, expected_output);
  tester.AddOutput<int64_t>("unflatten_dims", {2}, expected_unflatten_dims);
  tester.Run();
}

// float input of shape [2, 3, 3]: the two leading dims are flattened into six rows of three
// values; rows 1, 3 and 4 are expected in the output, together with the leading dims {2, 3}.
TEST(FlattenAndUnpadTest, FloatType2D) {
  const std::vector<float> padded_input = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
                                           4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
  const std::vector<int64_t> kept_indices = {1, 3, 4};

  const std::vector<float> expected_output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f};
  const std::vector<int64_t> expected_unflatten_dims = {2, 3};

  OpTester tester("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
  tester.AddInput<float>("input", {2, 3, 3}, padded_input);
  tester.AddInput<int64_t>("indices", {3}, kept_indices);
  tester.AddOutput<float>("output", {3, 3}, expected_output);
  tester.AddOutput<int64_t>("unflatten_dims", {2}, expected_unflatten_dims);
  tester.Run();
}

// Half-precision variant of the [5, 3] case. The reference data is written as float and converted
// to MLFloat16 so that the op is actually exercised with float16 tensors (feeding float tensors
// here would only repeat the float test).
TEST(FlattenAndUnpadTest, MLFloat16Type1D) {
  std::vector<float> input = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
                              0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};
  std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11};

  std::vector<float> output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f};
  std::vector<int64_t> unflatten_dims = {5, 3};

  // Convert both the input and the expected output to half precision.
  std::vector<MLFloat16> input_half;
  input_half.resize(input.size());
  ConvertFloatToMLFloat16(input.data(), input_half.data(), static_cast<int>(input.size()));
  std::vector<MLFloat16> output_half;
  output_half.resize(output.size());
  ConvertFloatToMLFloat16(output.data(), output_half.data(), static_cast<int>(output.size()));

  OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
  test.AddInput<MLFloat16>("input", {5, 3}, input_half);
  test.AddInput<int64_t>("indices", {6}, indices);
  test.AddOutput<MLFloat16>("output", {6}, output_half);
  test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
  test.Run();
}

// Half-precision variant of the [2, 3, 3] case. The reference data is written as float and
// converted to MLFloat16 so that the op is actually exercised with float16 tensors (feeding float
// tensors here would only repeat the float test).
TEST(FlattenAndUnpadTest, MLFloat16Type2D) {
  std::vector<float> input = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
                              4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
  std::vector<int64_t> indices = {1, 3, 4};

  std::vector<float> output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f, 7.f, 8.f, 9.f};
  std::vector<int64_t> unflatten_dims = {2, 3};

  // Convert both the input and the expected output to half precision.
  std::vector<MLFloat16> input_half;
  input_half.resize(input.size());
  ConvertFloatToMLFloat16(input.data(), input_half.data(), static_cast<int>(input.size()));
  std::vector<MLFloat16> output_half;
  output_half.resize(output.size());
  ConvertFloatToMLFloat16(output.data(), output_half.data(), static_cast<int>(output.size()));

  OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
  test.AddInput<MLFloat16>("input", {2, 3, 3}, input_half);
  test.AddInput<int64_t>("indices", {3}, indices);
  test.AddOutput<MLFloat16>("output", {3, 3}, output_half);
  test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
  test.Run();
}

#endif

} // namespace test
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,11 @@ TEST(PadAndUnflattenTest, FloatType1D) {
std::vector<float> output = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};

std::vector<int64_t> full_flatten_dims = {15};

OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain);
test.AddInput<float>("input", {6}, input);
test.AddInput<int64_t>("indices", {6}, indices);
test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.AddOutput<float>("output", {5, 3}, output);
test.AddOutput<int64_t>("full_flatten_dims", {1}, full_flatten_dims);
test.Run();
}

Expand All @@ -36,14 +33,11 @@ TEST(PadAndUnflattenTest, FloatType2D) {
std::vector<float> output = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};

std::vector<int64_t> full_flatten_dims = {6, 3};

OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain);
test.AddInput<float>("input", {3, 3}, input);
test.AddInput<int64_t>("indices", {3}, indices);
test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.AddOutput<float>("output", {2, 3, 3}, output);
test.AddOutput<int64_t>("full_flatten_dims", {2}, full_flatten_dims);
test.Run();
}

Expand All @@ -55,8 +49,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type1D) {
std::vector<float> output = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};

std::vector<int64_t> full_flatten_dims = {15};

std::vector<MLFloat16> input_half;
input_half.resize(input.size());
ConvertFloatToMLFloat16(input.data(), input_half.data(), int(input.size()));
Expand All @@ -69,7 +61,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type1D) {
test.AddInput<int64_t>("indices", {6}, indices);
test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.AddOutput<MLFloat16>("output", {5, 3}, output_half);
test.AddOutput<int64_t>("full_flatten_dims", {1}, full_flatten_dims);
test.Run();
}

Expand All @@ -81,8 +72,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type2D) {
std::vector<float> output = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};

std::vector<int64_t> full_flatten_dims = {6, 3};

std::vector<MLFloat16> input_half;
input_half.resize(input.size());
ConvertFloatToMLFloat16(input.data(), input_half.data(), int(input.size()));
Expand All @@ -95,7 +84,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type2D) {
test.AddInput<int64_t>("indices", {3}, indices);
test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
test.AddOutput<MLFloat16>("output", {2, 3, 3}, output_half);
test.AddOutput<int64_t>("full_flatten_dims", {2}, full_flatten_dims);
test.Run();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BatchScale);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PadAndUnflatten);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, FlattenAndUnpad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ScaledSum);

// the kernels within the following ifdef are not included in a build with
Expand Down Expand Up @@ -459,6 +460,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BatchScale)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PadAndUnflatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, FlattenAndUnpad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ScaledSum)>,
// the kernels within the following ifdef are not included in a build with
// --enable_training_ops but without --enable_training
Expand Down
Loading

0 comments on commit cbf9348

Please sign in to comment.