diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc
index 7100cedaf78a0..d55297de241b3 100755
--- a/orttraining/orttraining/core/graph/gradient_builder.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder.cc
@@ -791,13 +791,16 @@ IMPLEMENT_GRADIENT_BUILDER(GetGatherGradient) {
 
 IMPLEMENT_GRADIENT_BUILDER(GetPadAndUnflattenGradient) {
   return std::vector<NodeDef>{
-      NodeDef(OpDef("Reshape"),
-              {GO(0), O(1)},
-              {IA("GO_reshaped")}),
-      NodeDef(OpDef{"Gather", kOnnxDomain, 1},
-              {IA("GO_reshaped"), I(1)},
-              {GI(0)},
-              SrcNodeAttributes())};
+      NodeDef(OpDef{"FlattenAndUnpad", kMSDomain, 1},
+              {GO(0), I(1)},
+              {GI(0), IA("No_use")})};
+}
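+
+// FlattenAndUnpad and PadAndUnflatten are inverses of each other, so each serves as the
+// other's gradient. The second output of FlattenAndUnpad (unflatten_dims) is not needed
+// on the gradient path, hence the "No_use" placeholder output above.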
" unflatten_dims: [2, 3], shape is [2]" " output: [[[1, 2, 3, 4], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [5, 6, 7, 8]]]," - " shape is [2, 3, 4]" - " flatten_output_shape: [6, 4], shape is [2]") + " shape is [2, 3, 4]") .Input(0, "input", "input data of rank N, shape is [d1, d2, ..., dN]", "T") .Input(1, "indices", "1D Tensor of int32/int64 indices, shape is [d1], each element's value ranges in [0, M1*M2).", "T_INDEX") .Input(2, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT") .Output(0, "output", "output data of rank N+1, [M1, M2, d2, ..., dN]", "T") - .Output(1, "flatten_output_shape", "1D tensor with output shape, [M1*M2, d2, ..., dN]", "T_INT") .TypeConstraint( "T_INT", {"tensor(int32)", "tensor(int64)"}, @@ -4770,6 +4768,26 @@ Return true if all elements are true and false otherwise. {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types"); + ONNX_CONTRIB_OPERATOR_SCHEMA(FlattenAndUnpad) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc( + "FlattenAndUnpad operator flattens the first two dims of input tensor, and unpad according to given indices." + "This is used by padding elimination graph transformers.") + .Input(0, "input", "input data of rank N, shape is [M1, M2, d2, ..., dN]", "T") + .Input(1, "indices", "1D Tensor of int32/int64 indices, shape is [d1], each element's value ranges in [0, M1*M2).", + "T_INT") + .Output(0, "output", "output data of rank N-1, [d1, d2, ..., dN]", "T") + .Output(1, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT") + .TypeConstraint( + "T_INT", + {"tensor(int32)", "tensor(int64)"}, + "Constrain indices and shape to integer tensors.") + .TypeConstraint( + "T", + {"tensor(int32)", "tensor(int64)", "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, + "Constrain input and output types to float tensors."); + ONNX_CONTRIB_OPERATOR_SCHEMA(GRUTraining) .SetDomain(kMSDomain) .SinceVersion(1) diff --git a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc index 74247c059cf84..73638e8ba62a0 100644 --- a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc +++ b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc @@ -129,91 +129,43 @@ NodeArg* InsertExpandForNodeInput(Graph& graph, return new_expand_node->MutableOutputDefs()[0]; } -// Insert Reshape + ShrunkenGather to flatten the in_index-th input of node. +// Insert FlattenAndUnpad to flatten and unpad the in_index-th input of node. // The gather_index_arg is the indices of the elements that are not padding. 
 NodeArg* InsertFlattenPatternForInput(Graph& graph,
                                       Node& node,
                                       uint32_t in_index,
                                       NodeArg* gather_index_arg,
                                       const logging::Logger& logger) {
-  InlinedVector<NodeArg*> reshape_input_args;
-  reshape_input_args.reserve(2);
-  reshape_input_args.push_back(node.MutableInputDefs()[in_index]);
-  std::vector<int64_t> new_shape;
-  new_shape.push_back(-1);  // only support flatten 0 and 1 dims
-  auto input_shape = node.InputDefs()[in_index]->Shape();
-  ORT_ENFORCE(input_shape->dim_size() >= 2);
-  ONNX_NAMESPACE::TensorShapeProto flattened_shape;
-  if (input_shape->dim(0).has_dim_value() && input_shape->dim(1).has_dim_value()) {
-    flattened_shape.add_dim()->set_dim_value(input_shape->dim(0).dim_value() * input_shape->dim(1).dim_value());
-  } else {
-    std::string token_dim_name = MakeString("total_token_count_", utils::GetRandomSeed());
-    flattened_shape.add_dim()->set_dim_param(token_dim_name);
-  }
-  for (int k = 2; k < input_shape->dim_size(); k++) {
-    ORT_ENFORCE(input_shape->dim(k).has_dim_value());
-    new_shape.push_back(input_shape->dim(k).dim_value());
-    flattened_shape.add_dim()->set_dim_value(input_shape->dim(k).dim_value());
-  }
-  ONNX_NAMESPACE::TensorProto new_shape_const_tensor;
-  new_shape_const_tensor.set_name(graph.GenerateNodeArgName("new_shape"));
-  new_shape_const_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
-  new_shape_const_tensor.add_dims(new_shape.size());
-  new_shape_const_tensor.set_raw_data(new_shape.data(), new_shape.size() * sizeof(int64_t));
-  NodeArg* new_shape_arg = &graph_utils::AddInitializer(graph, new_shape_const_tensor);
-  reshape_input_args.push_back(new_shape_arg);
-
-  InlinedVector<NodeArg*> reshape_output_args;
-  reshape_output_args.push_back(
-      &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("inputs_reshape_result"),
-                                node.MutableInputDefs()[in_index]->TypeAsProto()));
-
-  Node* new_reshape_node = InsertIntermediateNodeOnDestInput(
-      graph, node,
-      in_index,
-      0,
-      0,
-      graph.GenerateNodeName("Reshape"),
-      "Reshape",
-      "Reshape node to filter invalid tokens.",
-      reshape_input_args,
-      reshape_output_args,
-      {},
-      "",
-      logger);
+  InlinedVector<NodeArg*> unpad_input_args;
+  unpad_input_args.reserve(2);
+  unpad_input_args.push_back(node.MutableInputDefs()[in_index]);
+  unpad_input_args.push_back(gather_index_arg);
 
-  new_reshape_node->SetExecutionProviderType(node.GetExecutionProviderType());
-  auto reshape_out_arg = new_reshape_node->MutableOutputDefs()[0];
-
-  reshape_out_arg->SetShape(flattened_shape);
-
-  InlinedVector<NodeArg*> gather_input_args;
-  gather_input_args.reserve(2);
-  gather_input_args.push_back(reshape_output_args[0]);
-  gather_input_args.push_back(gather_index_arg);
-
-  InlinedVector<NodeArg*> gather_output_args;
-  gather_output_args.push_back(
+  InlinedVector<NodeArg*> unpad_output_args;
+  unpad_output_args.push_back(
       &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padding_filter_result"),
-                                reshape_out_arg->TypeAsProto()));
+                                nullptr));
+  unpad_output_args.push_back(
+      &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("d1_d2_shape"),
+                                nullptr));
 
-  Node* new_gather_node = InsertIntermediateNodeOnDestInput(
+  Node* unpad_node = InsertIntermediateNodeOnDestInput(
       graph, node,
       in_index,
       0,
       0,
       graph.GenerateNodeName("PaddingFilter"),
-      "ShrunkenGather",
-      "ShrunkenGather node to filter invalid tokens.",
-      gather_input_args,
-      gather_output_args,
+      "FlattenAndUnpad",
+      "FlattenAndUnpad node to filter invalid tokens.",
+      unpad_input_args,
+      unpad_output_args,
       {},
       kMSDomain,
       logger);
 
-  new_gather_node->SetExecutionProviderType(node.GetExecutionProviderType());
-  auto gather_out_arg = new_gather_node->MutableOutputDefs()[0];
-  return gather_out_arg;
+  unpad_node->SetExecutionProviderType(node.GetExecutionProviderType());
+  auto unpad_out_arg = unpad_node->MutableOutputDefs()[0];
+  return unpad_out_arg;
 }
 
 // Insert PadAndUnflatten to unflatten the shape of the in_index-th input of node.
@@ -236,10 +188,6 @@ NodeArg* InsertNodesForOutput(Graph& graph,
   pad_node_output_args.push_back(
       &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_result"),
                                 nullptr));
-  pad_node_output_args.push_back(
-      &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_d1xd2_shape"),
-                                nullptr));
-
   Node* new_gathergrad_node = InsertIntermediateNodeOnDestInput(
       graph, node,
       in_index,
diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc
index 890a1bbccbc92..6fb42dd59b6a0 100644
--- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc
+++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc
@@ -3011,7 +3011,6 @@ TEST(GradientCheckerTest, PadAndUnflattenGrad) {
   std::vector<std::vector<float>> x_datas = {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 5, 0, 1}, {5, 2}};
 
   TensorInfo padded_out_info({5, 2, 3}, true);
-  TensorInfo out_shape_info({2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
 
   std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
 #ifdef USE_CUDA
@@ -3021,7 +3020,7 @@ TEST(GradientCheckerTest, PadAndUnflattenGrad) {
 #endif
 
   ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info, indices_info, shape_info},
-                                                         {padded_out_info, out_shape_info}, &max_error,
+                                                         {padded_out_info}, &max_error,
                                                          x_datas, {}, true, false, &execution_providers));
   EXPECT_IS_TINY(max_error);
 }
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index c8ec2e52f3078..13024b81f4b3c 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -5786,14 +5786,14 @@ def __init__(self, vocab_size, hidden_size, pad_token_id):
         # the test_op should be included in padding elimination subgraph and the PadAndUnflatten should be
         # added to output of test_op.
         # in case 2, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, 1, hidden_size],
-        # the test_op should be included in padding elimination subgraph and a 'Expand + Reshape + ShrunkenGather'
+        # the test_op should be included in padding elimination subgraph and an 'Expand + FlattenAndUnpad'
         # pattern should be insert to the arg of [batch_size, 1, hidden_size].
         # in case 3, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [1, hidden_size],
-        # the test_op should be included in padding elimination subgraph and a 'Expand + Reshape + ShrunkenGather'
+        # the test_op should be included in padding elimination subgraph and an 'Expand + FlattenAndUnpad'
         # pattern should be insert to the arg of [batch_size, 1, hidden_size].
         # in case 4, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, seqlen, hidden_size],
         # the test_op should be included in padding elimination subgraph and the PadAndUnflatten should be added to
-        # output of test_op. Besides, the other input of Add should be added 'Reshape + ShrunkenGather' to
+        # output of test_op. Besides, the other input of Add should be added 'FlattenAndUnpad' to
         # flatten and elimination padding.
         def test_elementwise(self, input_ids):
             input_shape = input_ids.size()
@@ -5905,9 +5905,9 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):
     assert len([node.op_type for node in training_model.graph.node if node.op_type == "Squeeze"]) == 1
     assert len([node.op_type for node in training_model.graph.node if node.op_type == "PadAndUnflatten"]) == 1
     if case >= 2:
-        assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 2
+        assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 3
     else:
-        assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 1
+        assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 2
     gathergrad_node = next(node for node in training_model.graph.node if node.op_type == "PadAndUnflatten")
 
     def find_input_node_type(model, arg):
@@ -6071,7 +6071,7 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):
     _test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-3, rtol=1e-4)
 
     training_model = ort_model._torch_module._execution_manager(True)._onnx_models.optimized_model
-    assert "ShrunkenGather" in [node.op_type for node in training_model.graph.node]
+    assert "FlattenAndUnpad" in [node.op_type for node in training_model.graph.node]
     assert "PadAndUnflatten" in [node.op_type for node in training_model.graph.node]
     del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"]
diff --git a/orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc b/orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc
new file mode 100644
index 0000000000000..e77afd4eaa90a
--- /dev/null
+++ b/orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc
@@ -0,0 +1,157 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
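+
+// FlattenAndUnpad flattens the leading two dims [M1, M2] of `input` and gathers the rows
+// listed in `indices` from the flattened [M1*M2, d2, ..., dN] view; the second output,
+// `unflatten_dims`, returns the original leading dims [M1, M2].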
+
+#include "test/common/tensor_op_test_utils.h"
+#include "test/providers/provider_test_utils.h"
+
+namespace onnxruntime {
+namespace test {
+
+#if defined(USE_CUDA) || defined(USE_ROCM)
+
+TEST(FlattenAndUnpadTest, Int32Type1D) {
+  std::vector<int32_t> input = {1, 1, 3, 2, 0, 3, 0, 4,
+                                0, 5, 0, 6, 0, 0, 0};
+  std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11};
+
+  std::vector<int32_t> output = {1, 2, 3, 4, 5, 6};
+  std::vector<int64_t> unflatten_dims = {5, 3};
+
+  OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
+  test.AddInput<int32_t>("input", {5, 3}, input);
+  test.AddInput<int64_t>("indices", {6}, indices);
+  test.AddOutput<int32_t>("output", {6}, output);
+  test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
+  test.Run();
+}
+
+TEST(FlattenAndUnpadTest, Int32Type2D) {
+  std::vector<int32_t> input = {0, 0, 0, 1, 2, 3, 0, 0, 0,
+                                4, 5, 6, 7, 8, 9, 0, 0, 0};
+  std::vector<int64_t> indices = {1, 3, 4};
+
+  std::vector<int32_t> output = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+  std::vector<int64_t> unflatten_dims = {2, 3};
+
+  OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
+  test.AddInput<int32_t>("input", {2, 3, 3}, input);
+  test.AddInput<int64_t>("indices", {3}, indices);
+  test.AddOutput<int32_t>("output", {3, 3}, output);
+  test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
+  test.Run();
+}
+
+TEST(FlattenAndUnpadTest, Int64Type1D) {
+  std::vector<int64_t> input = {1, 1, 3, 2, 0, 3, 0, 4,
+                                0, 5, 0, 6, 0, 0, 0};
+  std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11};
+
+  std::vector<int64_t> output = {1, 2, 3, 4, 5, 6};
+  std::vector<int64_t> unflatten_dims = {5, 3};
+
+  OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
+  test.AddInput<int64_t>("input", {5, 3}, input);
+  test.AddInput<int64_t>("indices", {6}, indices);
+  test.AddOutput<int64_t>("output", {6}, output);
+  test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
+  test.Run();
+}
+
+TEST(FlattenAndUnpadTest, Int64Type2D) {
+  std::vector<int64_t> input = {0, 0, 0, 1, 2, 3, 0, 0, 0,
+                                4, 5, 6, 7, 8, 9, 0, 0, 0};
+  std::vector<int64_t> indices = {1, 3, 4};
+
+  std::vector<int64_t> output = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+  std::vector<int64_t> unflatten_dims = {2, 3};
+
+  OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
+  test.AddInput<int64_t>("input", {2, 3, 3}, input);
+  test.AddInput<int64_t>("indices", {3}, indices);
+  test.AddOutput<int64_t>("output", {3, 3}, output);
+  test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
+  test.Run();
+}
+
+TEST(FlattenAndUnpadTest, FloatType1D) {
+  std::vector<float> input = {1.0f, 1.0f, 3.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
+                              0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};
+  std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11};
+
+  std::vector<float> output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f};
+  std::vector<int64_t> unflatten_dims = {5, 3};
+
+  OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
+  test.AddInput<float>("input", {5, 3}, input);
+  test.AddInput<int64_t>("indices", {6}, indices);
+  test.AddOutput<float>("output", {6}, output);
+  test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
+  test.Run();
+}
+
+TEST(FlattenAndUnpadTest, FloatType2D) {
+  std::vector<float> input = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
+                              4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
+  std::vector<int64_t> indices = {1, 3, 4};
+
+  std::vector<float> output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f, 7.f, 8.f, 9.f};
+  std::vector<int64_t> unflatten_dims = {2, 3};
+
+  OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
+  test.AddInput<float>("input", {2, 3, 3}, input);
+  test.AddInput<int64_t>("indices", {3}, indices);
+  test.AddOutput<float>("output", {3, 3}, output);
+  test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
+  test.Run();
+}
+
+TEST(FlattenAndUnpadTest, MLFloat16Type1D) {
+  std::vector<float> input = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
+                              0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};
+  std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11};
+
+  std::vector<float> output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f};
+  std::vector<int64_t> unflatten_dims = {5, 3};
+
+  std::vector<MLFloat16> input_half;
+  input_half.resize(input.size());
+  ConvertFloatToMLFloat16(input.data(), input_half.data(), static_cast<int>(input.size()));
+  std::vector<MLFloat16> output_half;
+  output_half.resize(output.size());
+  ConvertFloatToMLFloat16(output.data(), output_half.data(), static_cast<int>(output.size()));
+
+  OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
+  test.AddInput<MLFloat16>("input", {5, 3}, input_half);
+  test.AddInput<int64_t>("indices", {6}, indices);
+  test.AddOutput<MLFloat16>("output", {6}, output_half);
+  test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
+  test.Run();
+}
+
+TEST(FlattenAndUnpadTest, MLFloat16Type2D) {
+  std::vector<float> input = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
+                              4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
+  std::vector<int64_t> indices = {1, 3, 4};
+
+  std::vector<float> output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f, 7.f, 8.f, 9.f};
+  std::vector<int64_t> unflatten_dims = {2, 3};
+
+  std::vector<MLFloat16> input_half;
+  input_half.resize(input.size());
+  ConvertFloatToMLFloat16(input.data(), input_half.data(), static_cast<int>(input.size()));
+  std::vector<MLFloat16> output_half;
+  output_half.resize(output.size());
+  ConvertFloatToMLFloat16(output.data(), output_half.data(), static_cast<int>(output.size()));
+
+  OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain);
+  test.AddInput<MLFloat16>("input", {2, 3, 3}, input_half);
+  test.AddInput<int64_t>("indices", {3}, indices);
+  test.AddOutput<MLFloat16>("output", {3, 3}, output_half);
+  test.AddOutput<int64_t>("unflatten_dims", {2}, unflatten_dims);
+  test.Run();
+}
+
+#endif
+
+}  // namespace test
+}  // namespace onnxruntime
diff --git a/orttraining/orttraining/test/training_ops/cuda/pad_and_unflatten_test.cc b/orttraining/orttraining/test/training_ops/cuda/pad_and_unflatten_test.cc
index a800f17e59ae0..9a86955e09379 100644
--- a/orttraining/orttraining/test/training_ops/cuda/pad_and_unflatten_test.cc
+++ b/orttraining/orttraining/test/training_ops/cuda/pad_and_unflatten_test.cc
@@ -17,14 +17,11 @@ TEST(PadAndUnflattenTest, FloatType1D) {
   std::vector<float> output = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
                                0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};
 
-  std::vector<int64_t> full_flatten_dims = {15};
-
   OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain);
   test.AddInput<float>("input", {6}, input);
   test.AddInput<int64_t>("indices", {6}, indices);
   test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
   test.AddOutput<float>("output", {5, 3}, output);
-  test.AddOutput<int64_t>("full_flatten_dims", {1}, full_flatten_dims);
   test.Run();
 }
 
@@ -36,14 +33,11 @@ TEST(PadAndUnflattenTest, FloatType2D) {
   std::vector<float> output = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
                                4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
 
-  std::vector<int64_t> full_flatten_dims = {6, 3};
-
   OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain);
   test.AddInput<float>("input", {3, 3}, input);
   test.AddInput<int64_t>("indices", {3}, indices);
   test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
   test.AddOutput<float>("output", {2, 3, 3}, output);
-  test.AddOutput<int64_t>("full_flatten_dims", {2}, full_flatten_dims);
   test.Run();
 }
 
@@ -55,8 +49,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type1D) {
   std::vector<float> output = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
                                0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};
 
-  std::vector<int64_t> full_flatten_dims = {15};
-
   std::vector<MLFloat16> input_half;
   input_half.resize(input.size());
   ConvertFloatToMLFloat16(input.data(), input_half.data(), int(input.size()));
@@ -69,7 +61,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type1D) {
   test.AddInput<int64_t>("indices", {6}, indices);
   test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
   test.AddOutput<MLFloat16>("output", {5, 3}, output_half);
-  test.AddOutput<int64_t>("full_flatten_dims", {1}, full_flatten_dims);
   test.Run();
 }
 
@@ -81,8 +72,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type2D) {
   std::vector<float> output = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
                                4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
 
-  std::vector<int64_t> full_flatten_dims = {6, 3};
-
   std::vector<MLFloat16> input_half;
   input_half.resize(input.size());
   ConvertFloatToMLFloat16(input.data(), input_half.data(), int(input.size()));
@@ -95,7 +84,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type2D) {
   test.AddInput<int64_t>("indices", {3}, indices);
   test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
   test.AddOutput<MLFloat16>("output", {2, 3, 3}, output_half);
-  test.AddOutput<int64_t>("full_flatten_dims", {2}, full_flatten_dims);
   test.Run();
 }
 
diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
index ae4f48b6b49a2..69f3c8a6756c3 100644
--- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
+++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
@@ -206,6 +206,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BatchScale);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PadAndUnflatten);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, FlattenAndUnpad);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ScaledSum);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ResizeGrad);
@@ -460,6 +461,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BatchScale)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PadAndUnflatten)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, FlattenAndUnpad)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ScaledSum)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ResizeGrad)>,
diff --git a/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.cc b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.cc
new file mode 100644
index 0000000000000..c0897a6d0e20f
--- /dev/null
+++ b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.cc
@@ -0,0 +1,87 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "orttraining/training_ops/cuda/tensor/flatten_and_unpad.h"
+#include "orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h"
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
+
+namespace onnxruntime {
+namespace cuda {
+
+ONNX_OPERATOR_KERNEL_EX(
+    FlattenAndUnpad,
+    kMSDomain,
+    1,
+    kCudaExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, MLFloat16, float, double, BFloat16>())
+        .TypeConstraint("T_INT", DataTypeImpl::GetTensorType<int64_t>())
+        .OutputMemoryType(OrtMemTypeCPUOutput, 1),
+    FlattenAndUnpad);
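+
+// Output 1 (unflatten_dims) is pinned to CPU memory above: it is a tiny shape tensor, and
+// its consumer (PadAndUnflatten, whose input 2 is declared OrtMemTypeCPUInput) reads it on
+// the host.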
+
+// Put implementation in the anonymous namespace to avoid name collision in the global namespace.
+namespace {
+
+template <typename T>
+struct FlattenAndUnpadFunctor {
+  void operator()(cudaStream_t stream,
+                  const int64_t output_element_count,
+                  const fast_divmod output_element_stride_fdm,
+                  const int64_t index_value_upper_bound,
+                  const Tensor& input_tensor,
+                  const Tensor& indices_tensor,
+                  Tensor& output_tensor) const {
+    typedef typename ToCudaType<T>::MappedType CudaT;
+    const CudaT* input_data = reinterpret_cast<const CudaT*>(input_tensor.Data<T>());
+
+    FlattenAndUnpadImpl<CudaT>(stream, output_element_count, output_element_stride_fdm, index_value_upper_bound,
+                               input_data, indices_tensor.Data<int64_t>(),
+                               reinterpret_cast<CudaT*>(output_tensor.MutableData<T>()));
+  }
+};
+
+}  // namespace
+
+Status FlattenAndUnpad::ComputeInternal(OpKernelContext* context) const {
+  const Tensor* input_tensor = context->Input<Tensor>(0);
+  const Tensor* indices_tensor = context->Input<Tensor>(1);
+  ORT_ENFORCE(indices_tensor->Shape().NumDimensions() == 1,
+              "indices tensor must be 1-D.", indices_tensor->Shape().NumDimensions());
+
+  std::vector<int64_t> output_shape_vec;
+  output_shape_vec.push_back(indices_tensor->Shape()[0]);
+  const auto& input_shape = input_tensor->Shape();
+  int64_t element_stride = 1;
+  for (size_t i = 2; i < input_shape.NumDimensions(); ++i) {
+    output_shape_vec.push_back(input_shape[i]);
+    element_stride *= input_shape[i];
+  }
+
+  fast_divmod output_element_stride_fdm(static_cast<int>(element_stride));
+  auto output_shape = TensorShape(output_shape_vec);
+  Tensor* output_tensor = context->Output(0, output_shape);
+
+  std::vector<int64_t> unflatten_dims_vec;
+  unflatten_dims_vec.push_back(input_shape[0]);
+  unflatten_dims_vec.push_back(input_shape[1]);
+  const int64_t index_value_upper_bound = input_shape[0] * input_shape[1];
+
+  utils::MLTypeCallDispatcher<int32_t, int64_t, MLFloat16, float, double, BFloat16>
+      t_disp(input_tensor->GetElementType());
+  t_disp.Invoke<FlattenAndUnpadFunctor>(Stream(context),
+                                        output_shape.Size(),
+                                        output_element_stride_fdm,
+                                        index_value_upper_bound,
+                                        *input_tensor,
+                                        *indices_tensor,
+                                        *output_tensor);
+
+  size_t rank = unflatten_dims_vec.size();
+  Tensor* unflatten_dims_tensor = context->Output(1, {static_cast<int64_t>(rank)});
+  TensorShape(unflatten_dims_vec).CopyDims(unflatten_dims_tensor->MutableData<int64_t>(), rank);
+
+  return Status::OK();
+}
+
+}  // namespace cuda
+}  // namespace onnxruntime
diff --git a/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.h b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.h
new file mode 100644
index 0000000000000..f9c6819a393b8
--- /dev/null
+++ b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.h
@@ -0,0 +1,21 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/providers/cuda/cuda_kernel.h"
+#include "core/providers/common.h"
+
+namespace onnxruntime {
+namespace cuda {
+
+class FlattenAndUnpad final : public CudaKernel {
+ public:
+  FlattenAndUnpad(const OpKernelInfo& info) : CudaKernel(info) {
+  }
+
+  Status ComputeInternal(OpKernelContext* context) const override;
+};
+
+}  // namespace cuda
+}  // namespace onnxruntime
diff --git a/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.cu b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.cu
new file mode 100644
index 0000000000000..69cd0c7cd5445
--- /dev/null
+++ b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.cu
@@ -0,0 +1,83 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h"
+#include "core/providers/cuda/cu_inc/common.cuh"
+
+namespace onnxruntime {
+namespace cuda {
+
+constexpr int kBlockSize = 256;
+constexpr int kNumUnroll = 4;
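+
+// The kernel below is effectively a row gather over the flattened [M1*M2, stride] view of
+// the input: each thread produces kNumUnroll consecutive output elements, and fast_divmod
+// splits a flat output offset into (row, col) so the source row can be looked up in
+// indices_data.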
+
+template <typename T>
+__global__ void ExtractInputWithIndexKernel(const CUDA_LONG N,
+                                            const fast_divmod output_element_stride_fdm,
+                                            const int64_t index_value_upper_bound,
+                                            const T* input_data,
+                                            const int64_t* indices_data,
+                                            T* output_data) {
+  CUDA_LONG idx = blockDim.x * blockIdx.x + threadIdx.x;
+  CUDA_LONG id = idx * kNumUnroll;
+
+  T input[kNumUnroll];
+  if (id < N) {
+#pragma unroll
+    for (int i = 0; i < kNumUnroll; ++i) {
+      CUDA_LONG li = id + i;
+      if (li < N) {
+        int row_index, col_index;
+        output_element_stride_fdm.divmod(li, row_index, col_index);
+        assert(indices_data[row_index] < index_value_upper_bound);
+        input[i] = input_data[indices_data[row_index] * output_element_stride_fdm.d_ + col_index];
+      }
+    }
+  }
+
+#pragma unroll
+  for (int i = 0; i < kNumUnroll; ++i) {
+    CUDA_LONG li = id + i;
+    if (li < N) {
+      output_data[li] = input[i];
+    }
+  }
+}
+
+template <typename T>
+void FlattenAndUnpadImpl(cudaStream_t stream,
+                         const int64_t total_element_count,
+                         const fast_divmod output_element_stride_fdm,
+                         const int64_t index_value_upper_bound,
+                         const T* input_data,
+                         const int64_t* indices_data,
+                         T* output_data) {
+  const int blocksPerGrid = static_cast<int>(CeilDiv(total_element_count, kBlockSize * kNumUnroll));
+  ExtractInputWithIndexKernel<T><<<blocksPerGrid, kBlockSize, 0, stream>>>(
+      static_cast<CUDA_LONG>(total_element_count),
+      output_element_stride_fdm,
+      index_value_upper_bound,
+      input_data,
+      indices_data,
+      output_data);
+}
+
+#define SPECIALIZED_RESTORE_FROM_MASK_IMPL(T)                                         \
+  template void FlattenAndUnpadImpl<T>(cudaStream_t stream,                           \
+                                       const int64_t total_element_count,             \
+                                       const fast_divmod output_element_stride_fdm,   \
+                                       const int64_t index_value_upper_bound,         \
+                                       const T* input_data,                           \
+                                       const int64_t* indices_data,                   \
+                                       T* output_data);
+
+SPECIALIZED_RESTORE_FROM_MASK_IMPL(float)
+SPECIALIZED_RESTORE_FROM_MASK_IMPL(double)
+SPECIALIZED_RESTORE_FROM_MASK_IMPL(half)
+SPECIALIZED_RESTORE_FROM_MASK_IMPL(BFloat16)
+SPECIALIZED_RESTORE_FROM_MASK_IMPL(int32_t)
+SPECIALIZED_RESTORE_FROM_MASK_IMPL(int64_t)
+
+#undef SPECIALIZED_RESTORE_FROM_MASK_IMPL
+
+}  // namespace cuda
+}  // namespace onnxruntime
diff --git a/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h
new file mode 100644
index 0000000000000..75f8c243d3425
--- /dev/null
+++ b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h
@@ -0,0 +1,25 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#ifdef USE_ROCM
+#include "core/providers/rocm/shared_inc/rocm_utils.h"
+#else
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
+#endif
+
+namespace onnxruntime {
+namespace cuda {
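+
+// Writes output[i] = input[indices[i / stride] * stride + i % stride] for each flat output
+// offset i, where stride is the product of the non-flattened trailing dims.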
+template <typename T>
+void FlattenAndUnpadImpl(cudaStream_t stream,
+                         const int64_t total_element_count,
+                         const fast_divmod output_element_stride_fdm,
+                         const int64_t index_value_upper_bound,
+                         const T* input_data,
+                         const int64_t* indices_data,
+                         T* output_data);
+
+}  // namespace cuda
+}  // namespace onnxruntime
diff --git a/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc
index caf89ef840e0c..7bd759e8976c1 100644
--- a/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc
+++ b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc
@@ -17,8 +17,7 @@ ONNX_OPERATOR_KERNEL_EX(
         .TypeConstraint("T", BuildKernelDefConstraints<MLFloat16, float, double, BFloat16>())
         .TypeConstraint("T_INT", DataTypeImpl::GetTensorType<int64_t>())
         .TypeConstraint("T_INDEX", DataTypeImpl::GetTensorType<int64_t>())
-        .InputMemoryType(OrtMemTypeCPUInput, 2)
-        .OutputMemoryType(OrtMemTypeCPUOutput, 1),
+        .InputMemoryType(OrtMemTypeCPUInput, 2),
     PadAndUnflatten);
 
 // Put implementation in the anonymous namespace to avoid name collision in the global namespace.
@@ -63,14 +62,11 @@ Status PadAndUnflatten::ComputeInternal(OpKernelContext* context) const {
   output_shape_vec.push_back(dims_ptr[0]);
   output_shape_vec.push_back(dims_ptr[1]);
 
-  std::vector<int64_t> full_size_flatten_shape_vec;
   const int64_t flatten_dim_factor = dims_ptr[0] * dims_ptr[1];
-  full_size_flatten_shape_vec.push_back(flatten_dim_factor);
 
   int64_t element_stride = 1;
   for (size_t i = 1; i < input_shape.NumDimensions(); ++i) {
     output_shape_vec.push_back(input_shape[i]);
-    full_size_flatten_shape_vec.push_back(input_shape[i]);
     element_stride *= input_shape[i];
   }
 
@@ -87,11 +83,6 @@ Status PadAndUnflatten::ComputeInternal(OpKernelContext* context) const {
                                     *indices_tensor,
                                     *output_tensor);
 
-  // Set input shape output tensor.
-  size_t rank = full_size_flatten_shape_vec.size();
-  Tensor* input_shape_tensor = context->Output(1, {static_cast<int64_t>(rank)});
-  TensorShape(full_size_flatten_shape_vec).CopyDims(input_shape_tensor->MutableData<int64_t>(), rank);
-
   return Status::OK();
 }
 
diff --git a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc
index e0749c2fb4d0d..4768e688b93ac 100644
--- a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc
+++ b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc
@@ -190,6 +190,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PadA
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ResizeGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ResizeGrad);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, FlattenAndUnpad);
 
 #if defined(ORT_USE_NCCL) || defined(USE_MPI)
 // P2P communication operators.
@@ -393,6 +394,7 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ResizeGrad)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ResizeGrad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, FlattenAndUnpad)>,
 
      // P2P communication operators.
 #if defined(ORT_USE_NCCL) || defined(USE_MPI)