diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 133cab71f2b1c..868f25f227a59 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -755,13 +755,16 @@ IMPLEMENT_GRADIENT_BUILDER(GetGatherGradient) { IMPLEMENT_GRADIENT_BUILDER(GetPadAndUnflattenGradient) { return std::vector{ - NodeDef(OpDef("Reshape"), - {GO(0), O(1)}, - {IA("GO_reshaped")}), - NodeDef(OpDef{"Gather", kOnnxDomain, 1}, - {IA("GO_reshaped"), I(1)}, - {GI(0)}, - SrcNodeAttributes())}; + NodeDef(OpDef{"FlattenAndUnpad", kMSDomain, 1}, + {GO(0), I(1)}, + {GI(0), IA("No_use")})}; +} + +IMPLEMENT_GRADIENT_BUILDER(GetFlattenAndUnpadGradient) { + return std::vector{ + NodeDef(OpDef{"PadAndUnflatten", kMSDomain, 1}, + {GO(0), I(1), O(1)}, + {GI(0)})}; } IMPLEMENT_GRADIENT_BUILDER(GetShrunkenGatherGradient) { diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index a517e8af13fcc..6cf3d728625fa 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -40,6 +40,7 @@ DECLARE_GRADIENT_BUILDER(GetAveragePoolGradient) DECLARE_GRADIENT_BUILDER(GetMaxPoolGradient) DECLARE_GRADIENT_BUILDER(GetGatherGradient) DECLARE_GRADIENT_BUILDER(GetPadAndUnflattenGradient) +DECLARE_GRADIENT_BUILDER(GetFlattenAndUnpadGradient) DECLARE_GRADIENT_BUILDER(GetShrunkenGatherGradient) DECLARE_GRADIENT_BUILDER(GetConvGradient) DECLARE_GRADIENT_BUILDER(GetUnsqueezeGradient) diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index 4062b5d097394..0a503d96edb52 100755 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -71,6 +71,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("MaxPool", GetMaxPoolGradient); REGISTER_GRADIENT_BUILDER("Gather", GetGatherGradient); REGISTER_GRADIENT_BUILDER("PadAndUnflatten", GetPadAndUnflattenGradient); + REGISTER_GRADIENT_BUILDER("FlattenAndUnpad", GetFlattenAndUnpadGradient); REGISTER_GRADIENT_BUILDER("ShrunkenGather", GetShrunkenGatherGradient); REGISTER_GRADIENT_BUILDER("Conv", GetConvGradient); REGISTER_GRADIENT_BUILDER("Squeeze", GetSqueezeGradient); diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index cfc79455c43ed..7361e7419a9ce 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -4740,7 +4740,7 @@ Return true if all elements are true and false otherwise. "For other indices, the corresponding value in output will be padded to zero." "The indices don't allow duplicated index values, otherwise, though there is no runtime check" - "(in case of performance concern), the behaviour of output is undefined." + "(in case of performance concern), the behavior of output is undefined." "An example:" " input: [[1, 2, 3, 4], [5, 6, 7, 8]], shape is [2, 4]" @@ -4748,14 +4748,12 @@ Return true if all elements are true and false otherwise. " unflatten_dims: [2, 3], shape is [2]" " output: [[[1, 2, 3, 4], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [5, 6, 7, 8]]]," - " shape is [2, 3, 4]" - " flatten_output_shape: [6, 4], shape is [2]") + " shape is [2, 3, 4]") .Input(0, "input", "input data of rank N, shape is [d1, d2, ..., dN]", "T") .Input(1, "indices", "1D Tensor of int32/int64 indices, shape is [d1], each element's value ranges in [0, M1*M2).", "T_INDEX") .Input(2, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT") .Output(0, "output", "output data of rank N+1, [M1, M2, d2, ..., dN]", "T") - .Output(1, "flatten_output_shape", "1D tensor with output shape, [M1*M2, d2, ..., dN]", "T_INT") .TypeConstraint( "T_INT", {"tensor(int32)", "tensor(int64)"}, @@ -4769,6 +4767,26 @@ Return true if all elements are true and false otherwise. {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types"); + ONNX_CONTRIB_OPERATOR_SCHEMA(FlattenAndUnpad) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc( + "FlattenAndUnpad operator flattens the first two dims of input tensor, and unpad according to given indices." + "This is used by padding elimination graph transformers.") + .Input(0, "input", "input data of rank N, shape is [M1, M2, d2, ..., dN]", "T") + .Input(1, "indices", "1D Tensor of int32/int64 indices, shape is [d1], each element's value ranges in [0, M1*M2).", + "T_INT") + .Output(0, "output", "output data of rank N-1, [d1, d2, ..., dN]", "T") + .Output(1, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT") + .TypeConstraint( + "T_INT", + {"tensor(int32)", "tensor(int64)"}, + "Constrain indices and shape to integer tensors.") + .TypeConstraint( + "T", + {"tensor(int32)", "tensor(int64)", "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, + "Constrain input and output types to float tensors."); + ONNX_CONTRIB_OPERATOR_SCHEMA(GRUTraining) .SetDomain(kMSDomain) .SinceVersion(1) diff --git a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc index 74247c059cf84..73638e8ba62a0 100644 --- a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc +++ b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc @@ -129,91 +129,43 @@ NodeArg* InsertExpandForNodeInput(Graph& graph, return new_expand_node->MutableOutputDefs()[0]; } -// Insert Reshape + ShrunkenGather to flatten the in_index-th input of node. +// Insert FlattenAndUnpad to flatten and unpad the in_index-th input of node. // The gather_index_arg is the indices of the elements that are not padding. NodeArg* InsertFlattenPatternForInput(Graph& graph, Node& node, uint32_t in_index, NodeArg* gather_index_arg, const logging::Logger& logger) { - InlinedVector reshape_input_args; - reshape_input_args.reserve(2); - reshape_input_args.push_back(node.MutableInputDefs()[in_index]); - std::vector new_shape; - new_shape.push_back(-1); // only support flatten 0 and 1 dims - auto input_shape = node.InputDefs()[in_index]->Shape(); - ORT_ENFORCE(input_shape->dim_size() >= 2); - ONNX_NAMESPACE::TensorShapeProto flattened_shape; - if (input_shape->dim(0).has_dim_value() && input_shape->dim(1).has_dim_value()) { - flattened_shape.add_dim()->set_dim_value(input_shape->dim(0).dim_value() * input_shape->dim(1).dim_value()); - } else { - std::string token_dim_name = MakeString("total_token_count_", utils::GetRandomSeed()); - flattened_shape.add_dim()->set_dim_param(token_dim_name); - } - for (int k = 2; k < input_shape->dim_size(); k++) { - ORT_ENFORCE(input_shape->dim(k).has_dim_value()); - new_shape.push_back(input_shape->dim(k).dim_value()); - flattened_shape.add_dim()->set_dim_value(input_shape->dim(k).dim_value()); - } - ONNX_NAMESPACE::TensorProto new_shape_const_tensor; - new_shape_const_tensor.set_name(graph.GenerateNodeArgName("new_shape")); - new_shape_const_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); - new_shape_const_tensor.add_dims(new_shape.size()); - new_shape_const_tensor.set_raw_data(new_shape.data(), new_shape.size() * sizeof(int64_t)); - NodeArg* new_shape_arg = &graph_utils::AddInitializer(graph, new_shape_const_tensor); - reshape_input_args.push_back(new_shape_arg); - - InlinedVector reshape_output_args; - reshape_output_args.push_back( - &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("inputs_reshape_result"), - node.MutableInputDefs()[in_index]->TypeAsProto())); - - Node* new_reshape_node = InsertIntermediateNodeOnDestInput( - graph, node, - in_index, - 0, - 0, - graph.GenerateNodeName("Reshape"), - "Reshape", - "Reshape node to filter invalid tokens.", - reshape_input_args, - reshape_output_args, - {}, - "", - logger); + InlinedVector unpad_input_args; + unpad_input_args.reserve(2); + unpad_input_args.push_back(node.MutableInputDefs()[in_index]); + unpad_input_args.push_back(gather_index_arg); - new_reshape_node->SetExecutionProviderType(node.GetExecutionProviderType()); - auto reshape_out_arg = new_reshape_node->MutableOutputDefs()[0]; - - reshape_out_arg->SetShape(flattened_shape); - - InlinedVector gather_input_args; - gather_input_args.reserve(2); - gather_input_args.push_back(reshape_output_args[0]); - gather_input_args.push_back(gather_index_arg); - - InlinedVector gather_output_args; - gather_output_args.push_back( + InlinedVector unpad_output_args; + unpad_output_args.push_back( &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padding_filter_result"), - reshape_out_arg->TypeAsProto())); + nullptr)); + unpad_output_args.push_back( + &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("d1_d2_shape"), + nullptr)); - Node* new_gather_node = InsertIntermediateNodeOnDestInput( + Node* unpad_node = InsertIntermediateNodeOnDestInput( graph, node, in_index, 0, 0, graph.GenerateNodeName("PaddingFilter"), - "ShrunkenGather", - "ShrunkenGather node to filter invalid tokens.", - gather_input_args, - gather_output_args, + "FlattenAndUnpad", + "FlattenAndUnpad node to filter invalid tokens.", + unpad_input_args, + unpad_output_args, {}, kMSDomain, logger); - new_gather_node->SetExecutionProviderType(node.GetExecutionProviderType()); - auto gather_out_arg = new_gather_node->MutableOutputDefs()[0]; - return gather_out_arg; + unpad_node->SetExecutionProviderType(node.GetExecutionProviderType()); + auto unpad_out_arg = unpad_node->MutableOutputDefs()[0]; + return unpad_out_arg; } // Insert PadAndUnflatten to unflatten the shape of the in_index-th input of node. @@ -236,10 +188,6 @@ NodeArg* InsertNodesForOutput(Graph& graph, pad_node_output_args.push_back( &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_result"), nullptr)); - pad_node_output_args.push_back( - &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_d1xd2_shape"), - nullptr)); - Node* new_gathergrad_node = InsertIntermediateNodeOnDestInput( graph, node, in_index, diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 597801f4030c1..5da69f7a99e6c 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -3011,7 +3011,6 @@ TEST(GradientCheckerTest, PadAndUnflattenGrad) { std::vector> x_datas = {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 5, 0, 1}, {5, 2}}; TensorInfo padded_out_info({5, 2, 3}, true); - TensorInfo out_shape_info({2}, false, nullptr, DataTypeImpl::GetTensorType()); std::vector> execution_providers; #ifdef USE_CUDA @@ -3021,7 +3020,7 @@ TEST(GradientCheckerTest, PadAndUnflattenGrad) { #endif ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info, indices_info, shape_info}, - {padded_out_info, out_shape_info}, &max_error, + {padded_out_info}, &max_error, x_datas, {}, true, false, &execution_providers)); EXPECT_IS_TINY(max_error); } diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 643d47b0d043e..44d556aaa31ab 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -5782,14 +5782,14 @@ def __init__(self, vocab_size, hidden_size, pad_token_id): # the test_op should be included in padding elimination subgraph and the PadAndUnflatten should be # added to output of test_op. # in case 2, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, 1, hidden_size], - # the test_op should be included in padding elimination subgraph and a 'Expand + Reshape + ShrunkenGather' + # the test_op should be included in padding elimination subgraph and a 'Expand + FlattenAndUnpad' # pattern should be insert to the arg of [batch_size, 1, hidden_size]. # in case 3, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [1, hidden_size], - # the test_op should be included in padding elimination subgraph and a 'Expand + Reshape + ShrunkenGather' + # the test_op should be included in padding elimination subgraph and a 'Expand + FlattenAndUnpad' # pattern should be insert to the arg of [batch_size, 1, hidden_size]. # in case 4, the shapes of inputs of test_op are [batch_size, seqlen, hidden_size] and [batch_size, seqlen, hidden_size], # the test_op should be included in padding elimination subgraph and the PadAndUnflatten should be added to - # output of test_op. Besides, the other input of Add should be added 'Reshape + ShrunkenGather' to + # output of test_op. Besides, the other input of Add should be added 'FlattenAndUnpad' to # flatten and elimination padding. def test_elementwise(self, input_ids): input_shape = input_ids.size() @@ -5901,9 +5901,9 @@ def generate_inputs(batch_size, max_seq_length, vocab_size): assert len([node.op_type for node in training_model.graph.node if node.op_type == "Squeeze"]) == 1 assert len([node.op_type for node in training_model.graph.node if node.op_type == "PadAndUnflatten"]) == 1 if case >= 2: - assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 2 + assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 3 else: - assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 1 + assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 2 gathergrad_node = next(node for node in training_model.graph.node if node.op_type == "PadAndUnflatten") def find_input_node_type(model, arg): @@ -6067,7 +6067,7 @@ def generate_inputs(batch_size, max_seq_length, vocab_size): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-3, rtol=1e-4) training_model = ort_model._torch_module._execution_manager(True)._onnx_models.optimized_model - assert "ShrunkenGather" in [node.op_type for node in training_model.graph.node] + assert "FlattenAndUnpad" in [node.op_type for node in training_model.graph.node] assert "PadAndUnflatten" in [node.op_type for node in training_model.graph.node] del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] diff --git a/orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc b/orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc new file mode 100644 index 0000000000000..e77afd4eaa90a --- /dev/null +++ b/orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc @@ -0,0 +1,157 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test/common/tensor_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +#if defined(USE_CUDA) || defined(USE_ROCM) + +TEST(FlattenAndUnpadTest, Int32Type1D) { + std::vector input = {1, 1, 3, 2, 0, 3, 0, 4, + 0, 5, 0, 6, 0, 0, 0}; + std::vector indices = {1, 3, 5, 7, 9, 11}; + + std::vector output = {1, 2, 3, 4, 5, 6}; + std::vector unflatten_dims = {5, 3}; + + OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain); + test.AddInput("input", {5, 3}, input); + test.AddInput("indices", {6}, indices); + test.AddOutput("output", {6}, output); + test.AddOutput("unflatten_dims", {2}, unflatten_dims); + test.Run(); +} + +TEST(FlattenAndUnpadTest, Int32Type2D) { + std::vector input = {0, 0, 0, 1, 2, 3, 0, 0, 0, + 4, 5, 6, 7, 8, 9, 0, 0, 0}; + std::vector indices = {1, 3, 4}; + + std::vector output = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + std::vector unflatten_dims = {2, 3}; + + OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain); + test.AddInput("input", {2, 3, 3}, input); + test.AddInput("indices", {3}, indices); + test.AddOutput("output", {3, 3}, output); + test.AddOutput("unflatten_dims", {2}, unflatten_dims); + test.Run(); +} + +TEST(FlattenAndUnpadTest, Int64Type1D) { + std::vector input = {1, 1, 3, 2, 0, 3, 0, 4, + 0, 5, 0, 6, 0, 0, 0}; + std::vector indices = {1, 3, 5, 7, 9, 11}; + + std::vector output = {1, 2, 3, 4, 5, 6}; + std::vector unflatten_dims = {5, 3}; + + OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain); + test.AddInput("input", {5, 3}, input); + test.AddInput("indices", {6}, indices); + test.AddOutput("output", {6}, output); + test.AddOutput("unflatten_dims", {2}, unflatten_dims); + test.Run(); +} + +TEST(FlattenAndUnpadTest, Int64Type2D) { + std::vector input = {0, 0, 0, 1, 2, 3, 0, 0, 0, + 4, 5, 6, 7, 8, 9, 0, 0, 0}; + std::vector indices = {1, 3, 4}; + + std::vector output = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + std::vector unflatten_dims = {2, 3}; + + OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain); + test.AddInput("input", {2, 3, 3}, input); + test.AddInput("indices", {3}, indices); + test.AddOutput("output", {3, 3}, output); + test.AddOutput("unflatten_dims", {2}, unflatten_dims); + test.Run(); +} + +TEST(FlattenAndUnpadTest, FloatType1D) { + std::vector input = {1.0f, 1.0f, 3.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f, + 0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f}; + std::vector indices = {1, 3, 5, 7, 9, 11}; + + std::vector output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f}; + std::vector unflatten_dims = {5, 3}; + + OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain); + test.AddInput("input", {5, 3}, input); + test.AddInput("indices", {6}, indices); + test.AddOutput("output", {6}, output); + test.AddOutput("unflatten_dims", {2}, unflatten_dims); + test.Run(); +} + +TEST(FlattenAndUnpadTest, FloatType2D) { + std::vector input = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f, + 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f}; + std::vector indices = {1, 3, 4}; + + std::vector output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f, 7.f, 8.f, 9.f}; + std::vector unflatten_dims = {2, 3}; + + OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain); + test.AddInput("input", {2, 3, 3}, input); + test.AddInput("indices", {3}, indices); + test.AddOutput("output", {3, 3}, output); + test.AddOutput("unflatten_dims", {2}, unflatten_dims); + test.Run(); +} + +TEST(FlattenAndUnpadTest, MLFloat16Type1D) { + std::vector input = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f, + 0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f}; + std::vector indices = {1, 3, 5, 7, 9, 11}; + + std::vector output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f}; + std::vector unflatten_dims = {5, 3}; + + std::vector input_half; + input_half.resize(input.size()); + ConvertFloatToMLFloat16(input.data(), input_half.data(), static_cast(input.size())); + std::vector output_half; + output_half.resize(output.size()); + ConvertFloatToMLFloat16(output.data(), output_half.data(), static_cast(output.size())); + + OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain); + test.AddInput("input", {5, 3}, input_half); + test.AddInput("indices", {6}, indices); + test.AddOutput("output", {6}, output_half); + test.AddOutput("unflatten_dims", {2}, unflatten_dims); + test.Run(); +} + +TEST(FlattenAndUnpadTest, MLFloat16Type2D) { + std::vector input = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f, + 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f}; + std::vector indices = {1, 3, 4}; + + std::vector output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f, 7.f, 8.f, 9.f}; + std::vector unflatten_dims = {2, 3}; + + std::vector input_half; + input_half.resize(input.size()); + ConvertFloatToMLFloat16(input.data(), input_half.data(), static_cast(input.size())); + std::vector output_half; + output_half.resize(output.size()); + ConvertFloatToMLFloat16(output.data(), output_half.data(), static_cast(output.size())); + + OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain); + test.AddInput("input", {2, 3, 3}, input_half); + test.AddInput("indices", {3}, indices); + test.AddOutput("output", {3, 3}, output_half); + test.AddOutput("unflatten_dims", {2}, unflatten_dims); + test.Run(); +} + +#endif + +} // namespace test +} // namespace onnxruntime diff --git a/orttraining/orttraining/test/training_ops/cuda/pad_and_unflatten_test.cc b/orttraining/orttraining/test/training_ops/cuda/pad_and_unflatten_test.cc index a800f17e59ae0..9a86955e09379 100644 --- a/orttraining/orttraining/test/training_ops/cuda/pad_and_unflatten_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/pad_and_unflatten_test.cc @@ -17,14 +17,11 @@ TEST(PadAndUnflattenTest, FloatType1D) { std::vector output = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f, 0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f}; - std::vector full_flatten_dims = {15}; - OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain); test.AddInput("input", {6}, input); test.AddInput("indices", {6}, indices); test.AddInput("unflatten_dims", {2}, unflatten_dims); test.AddOutput("output", {5, 3}, output); - test.AddOutput("full_flatten_dims", {1}, full_flatten_dims); test.Run(); } @@ -36,14 +33,11 @@ TEST(PadAndUnflattenTest, FloatType2D) { std::vector output = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f}; - std::vector full_flatten_dims = {6, 3}; - OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain); test.AddInput("input", {3, 3}, input); test.AddInput("indices", {3}, indices); test.AddInput("unflatten_dims", {2}, unflatten_dims); test.AddOutput("output", {2, 3, 3}, output); - test.AddOutput("full_flatten_dims", {2}, full_flatten_dims); test.Run(); } @@ -55,8 +49,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type1D) { std::vector output = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f, 0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f}; - std::vector full_flatten_dims = {15}; - std::vector input_half; input_half.resize(input.size()); ConvertFloatToMLFloat16(input.data(), input_half.data(), int(input.size())); @@ -69,7 +61,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type1D) { test.AddInput("indices", {6}, indices); test.AddInput("unflatten_dims", {2}, unflatten_dims); test.AddOutput("output", {5, 3}, output_half); - test.AddOutput("full_flatten_dims", {1}, full_flatten_dims); test.Run(); } @@ -81,8 +72,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type2D) { std::vector output = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f}; - std::vector full_flatten_dims = {6, 3}; - std::vector input_half; input_half.resize(input.size()); ConvertFloatToMLFloat16(input.data(), input_half.data(), int(input.size())); @@ -95,7 +84,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type2D) { test.AddInput("indices", {3}, indices); test.AddInput("unflatten_dims", {2}, unflatten_dims); test.AddOutput("output", {2, 3, 3}, output_half); - test.AddOutput("full_flatten_dims", {2}, full_flatten_dims); test.Run(); } diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index 8e61dbee506f2..19a36cbb536a8 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -206,6 +206,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BatchScale); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PadAndUnflatten); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, FlattenAndUnpad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ScaledSum); // the kernels within the following ifdef are not included in a build with @@ -459,6 +460,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad)>, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // the kernels within the following ifdef are not included in a build with // --enable_training_ops but without --enable_training diff --git a/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.cc b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.cc new file mode 100644 index 0000000000000..c0897a6d0e20f --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.cc @@ -0,0 +1,87 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/training_ops/cuda/tensor/flatten_and_unpad.h" +#include "orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace cuda { + +ONNX_OPERATOR_KERNEL_EX( + FlattenAndUnpad, + kMSDomain, + 1, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", BuildKernelDefConstraints()) + .TypeConstraint("T_INT", DataTypeImpl::GetTensorType()) + .OutputMemoryType(OrtMemTypeCPUOutput, 1), + FlattenAndUnpad); + +// Put implementation in the anonymous namespace to avoid name collision in the global namespace. +namespace { + +template +struct FlattenAndUnpadFunctor { + void operator()(cudaStream_t stream, + const int64_t output_element_count, + const fast_divmod output_element_stride_fdm, + const int64_t index_value_upper_bound, + const Tensor& input_tensor, + const Tensor& indices_tensor, + Tensor& output_tensor) const { + typedef typename ToCudaType::MappedType CudaT; + const CudaT* input_data = reinterpret_cast(input_tensor.Data()); + + FlattenAndUnpadImpl(stream, output_element_count, output_element_stride_fdm, index_value_upper_bound, + input_data, indices_tensor.Data(), + reinterpret_cast(output_tensor.MutableData())); + } +}; + +} // namespace + +Status FlattenAndUnpad::ComputeInternal(OpKernelContext* context) const { + const Tensor* input_tensor = context->Input(0); + const Tensor* indices_tensor = context->Input(1); + ORT_ENFORCE(indices_tensor->Shape().NumDimensions() == 1, + "indices_tensor tensor must be 1-D.", indices_tensor->Shape().NumDimensions()); + + std::vector output_shape_vec; + output_shape_vec.push_back(indices_tensor->Shape()[0]); + const auto& input_shape = input_tensor->Shape(); + int64_t element_stride = 1; + for (size_t i = 2; i < input_shape.NumDimensions(); ++i) { + output_shape_vec.push_back(input_shape[i]); + element_stride *= input_shape[i]; + } + + fast_divmod output_element_stride_fdm(static_cast(element_stride)); + auto output_shape = TensorShape(output_shape_vec); + Tensor* output_tensor = context->Output(0, output_shape); + + std::vector unflatten_dims_vec; + unflatten_dims_vec.push_back(input_shape[0]); + unflatten_dims_vec.push_back(input_shape[1]); + const int64_t index_value_upper_bound = input_shape[0] * input_shape[1]; + + utils::MLTypeCallDispatcher + t_disp(input_tensor->GetElementType()); + t_disp.Invoke(Stream(context), + output_shape.Size(), + output_element_stride_fdm, + index_value_upper_bound, + *input_tensor, + *indices_tensor, + *output_tensor); + + size_t rank = unflatten_dims_vec.size(); + Tensor* unflatten_dims_tensor = context->Output(1, {static_cast(rank)}); + TensorShape(unflatten_dims_vec).CopyDims(unflatten_dims_tensor->MutableData(), rank); + + return Status::OK(); +} + +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.h b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.h new file mode 100644 index 0000000000000..f9c6819a393b8 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/common.h" + +namespace onnxruntime { +namespace cuda { + +class FlattenAndUnpad final : public CudaKernel { + public: + FlattenAndUnpad(const OpKernelInfo& info) : CudaKernel(info) { + } + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.cu b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.cu new file mode 100644 index 0000000000000..69cd0c7cd5445 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.cu @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" + +namespace onnxruntime { +namespace cuda { + +constexpr int kBlockSize = 256; +constexpr int kNumUnroll = 4; + +template +__global__ void ExtractIputWithIndexKernel(const CUDA_LONG N, + const fast_divmod output_element_stride_fdm, + const int64_t index_value_upper_bound, + const T* input_data, + const int64_t* indices_data, + T* output_data) { + CUDA_LONG idx = blockDim.x * blockIdx.x + threadIdx.x; + CUDA_LONG id = idx * kNumUnroll; + + T input[kNumUnroll]; + if (id < N) { +#pragma unroll + for (int i = 0; i < kNumUnroll; ++i) { + CUDA_LONG li = id + i; + if (li < N) { + int row_index, col_index; + output_element_stride_fdm.divmod(li, row_index, col_index); + assert(indices_data[row_index] < index_value_upper_bound); + input[i] = input_data[indices_data[row_index] * output_element_stride_fdm.d_ + col_index]; + } + } + } + +#pragma unroll + for (int i = 0; i < kNumUnroll; ++i) { + CUDA_LONG li = id + i; + if (li < N) { + output_data[li] = input[i]; + } + } +} + +template +void FlattenAndUnpadImpl(cudaStream_t stream, + const int64_t total_element_count, + const fast_divmod output_element_stride_fdm, + const int64_t index_value_upper_bound, + const T* input_data, + const int64_t* indices_data, + T* output_data) { + const int blocksPerGrid = static_cast(CeilDiv(total_element_count, kBlockSize * kNumUnroll)); + ExtractIputWithIndexKernel<<>>( + static_cast(total_element_count), + output_element_stride_fdm, + index_value_upper_bound, + input_data, + indices_data, + output_data); +} + +#define SPECIALIZED_RESTORE_FROM_MASK_IMPL(T) \ + template void FlattenAndUnpadImpl(cudaStream_t stream, \ + const int64_t total_element_count, \ + const fast_divmod output_element_stride_fdm, \ + const int64_t index_value_upper_bound, \ + const T* input_data, \ + const int64_t* indices_data, \ + T* output_data); + +SPECIALIZED_RESTORE_FROM_MASK_IMPL(float) +SPECIALIZED_RESTORE_FROM_MASK_IMPL(double) +SPECIALIZED_RESTORE_FROM_MASK_IMPL(half) +SPECIALIZED_RESTORE_FROM_MASK_IMPL(BFloat16) +SPECIALIZED_RESTORE_FROM_MASK_IMPL(int32_t) +SPECIALIZED_RESTORE_FROM_MASK_IMPL(int64_t) + +#undef SPECIALIZED_RESTORE_FROM_MASK_IMPL + +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h new file mode 100644 index 0000000000000..75f8c243d3425 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#ifdef USE_ROCM +#include "core/providers/rocm/shared_inc/rocm_utils.h" +#else +#include "core/providers/cuda/shared_inc/cuda_utils.h" +#endif + +namespace onnxruntime { +namespace cuda { + +template +void FlattenAndUnpadImpl(cudaStream_t stream, + const int64_t total_element_count, + const fast_divmod output_element_stride_fdm, + const int64_t index_value_upper_bound, + const T* input_data, + const int64_t* indices_data, + T* output_data); + +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc index caf89ef840e0c..7bd759e8976c1 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc +++ b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc @@ -17,8 +17,7 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("T", BuildKernelDefConstraints()) .TypeConstraint("T_INT", DataTypeImpl::GetTensorType()) .TypeConstraint("T_INDEX", DataTypeImpl::GetTensorType()) - .InputMemoryType(OrtMemTypeCPUInput, 2) - .OutputMemoryType(OrtMemTypeCPUOutput, 1), + .InputMemoryType(OrtMemTypeCPUInput, 2), PadAndUnflatten); // Put implementation in the anonymous namespace to avoid name collision in the global namespace. @@ -63,14 +62,11 @@ Status PadAndUnflatten::ComputeInternal(OpKernelContext* context) const { output_shape_vec.push_back(dims_ptr[0]); output_shape_vec.push_back(dims_ptr[1]); - std::vector full_size_flatten_shape_vec; const int64_t flatten_dim_factor = dims_ptr[0] * dims_ptr[1]; - full_size_flatten_shape_vec.push_back(flatten_dim_factor); int64_t element_stride = 1; for (size_t i = 1; i < input_shape.NumDimensions(); ++i) { output_shape_vec.push_back(input_shape[i]); - full_size_flatten_shape_vec.push_back(input_shape[i]); element_stride *= input_shape[i]; } @@ -87,11 +83,6 @@ Status PadAndUnflatten::ComputeInternal(OpKernelContext* context) const { *indices_tensor, *output_tensor); - // Set input shape output tensor. - size_t rank = full_size_flatten_shape_vec.size(); - Tensor* input_shape_tensor = context->Output(1, {static_cast(rank)}); - TensorShape(full_size_flatten_shape_vec).CopyDims(input_shape_tensor->MutableData(), rank); - return Status::OK(); } diff --git a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc index 2321aa23dd6eb..c959bd6c2eb6e 100644 --- a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc +++ b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc @@ -187,6 +187,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_BFloat16, ReduceAllL2); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PadAndUnflatten); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, FlattenAndUnpad); #if defined(ORT_USE_NCCL) || defined(USE_MPI) // P2P communication operators. @@ -387,6 +388,7 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // P2P communication operators. #if defined(ORT_USE_NCCL) || defined(USE_MPI)