From cbf9348f7a0c9107253d99c4e37ee12bd5e9c5f3 Mon Sep 17 00:00:00 2001 From: guyang3532 Date: Mon, 9 Oct 2023 11:10:47 +0000 Subject: [PATCH] Add FlattenAndUnpad Op --- .../core/graph/gradient_builder.cc | 17 ++-- .../orttraining/core/graph/gradient_builder.h | 1 + .../core/graph/gradient_builder_registry.cc | 1 + .../core/graph/training_op_defs.cc | 24 ++++- .../compute_optimizer/padding_elimination.cc | 90 ++++--------------- .../test/gradient/gradient_ops_test.cc | 3 +- .../python/orttraining_test_ortmodule_api.py | 4 +- .../cuda/flatten_and_unpad_test.cc | 79 ++++++++++++++++ .../cuda/pad_and_unflatten_test.cc | 12 --- .../cuda/cuda_training_kernels.cc | 2 + .../cuda/tensor/flatten_and_unpad.cc | 86 ++++++++++++++++++ .../cuda/tensor/flatten_and_unpad.h | 21 +++++ .../cuda/tensor/flatten_and_unpad_impl.cu | 83 +++++++++++++++++ .../cuda/tensor/flatten_and_unpad_impl.h | 25 ++++++ .../cuda/tensor/pad_and_unflatten.cc | 11 +-- .../rocm/rocm_training_kernels.cc | 2 + 16 files changed, 354 insertions(+), 107 deletions(-) create mode 100644 orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc create mode 100644 orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.cc create mode 100644 orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.h create mode 100644 orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.cu create mode 100644 orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index b3da4f3977ff2..0dd00085d7661 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -755,13 +755,16 @@ IMPLEMENT_GRADIENT_BUILDER(GetGatherGradient) { IMPLEMENT_GRADIENT_BUILDER(GetPadAndUnflattenGradient) { return std::vector{ - NodeDef(OpDef("Reshape"), - {GO(0), O(1)}, - {IA("GO_reshaped")}), - 
NodeDef(OpDef{"Gather", kOnnxDomain, 1}, - {IA("GO_reshaped"), I(1)}, - {GI(0)}, - SrcNodeAttributes())}; + NodeDef(OpDef{"FlattenAndUnpad", kMSDomain, 1}, + {GO(0), I(1)}, + {GI(0), IA("No_use")})}; +} + +IMPLEMENT_GRADIENT_BUILDER(GetFlattenAndUnpadGradient) { + return std::vector{ + NodeDef(OpDef{"PadAndUnflatten", kMSDomain, 1}, + {GO(0), I(1), O(1)}, + {GI(0)})}; } IMPLEMENT_GRADIENT_BUILDER(GetShrunkenGatherGradient) { diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index a517e8af13fcc..6cf3d728625fa 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -40,6 +40,7 @@ DECLARE_GRADIENT_BUILDER(GetAveragePoolGradient) DECLARE_GRADIENT_BUILDER(GetMaxPoolGradient) DECLARE_GRADIENT_BUILDER(GetGatherGradient) DECLARE_GRADIENT_BUILDER(GetPadAndUnflattenGradient) +DECLARE_GRADIENT_BUILDER(GetFlattenAndUnpadGradient) DECLARE_GRADIENT_BUILDER(GetShrunkenGatherGradient) DECLARE_GRADIENT_BUILDER(GetConvGradient) DECLARE_GRADIENT_BUILDER(GetUnsqueezeGradient) diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index 4062b5d097394..0a503d96edb52 100755 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -71,6 +71,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("MaxPool", GetMaxPoolGradient); REGISTER_GRADIENT_BUILDER("Gather", GetGatherGradient); REGISTER_GRADIENT_BUILDER("PadAndUnflatten", GetPadAndUnflattenGradient); + REGISTER_GRADIENT_BUILDER("FlattenAndUnpad", GetFlattenAndUnpadGradient); REGISTER_GRADIENT_BUILDER("ShrunkenGather", GetShrunkenGatherGradient); REGISTER_GRADIENT_BUILDER("Conv", GetConvGradient); REGISTER_GRADIENT_BUILDER("Squeeze", GetSqueezeGradient); diff --git 
a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 5cd29303c3639..3c94d37405e0f 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -4737,14 +4737,12 @@ Return true if all elements are true and false otherwise. " unflatten_dims: [2, 3], shape is [2]" " output: [[[1, 2, 3, 4], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [5, 6, 7, 8]]]," - " shape is [2, 3, 4]" - " flatten_output_shape: [6, 4], shape is [2]") + " shape is [2, 3, 4]") .Input(0, "input", "input data of rank N, shape is [d1, d2, ..., dN]", "T") .Input(1, "indices", "1D Tensor of int32/int64 indices, shape is [d1], each element's value ranges in [0, M1*M2).", "T_INDEX") .Input(2, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT") .Output(0, "output", "output data of rank N+1, [M1, M2, d2, ..., dN]", "T") - .Output(1, "flatten_output_shape", "1D tensor with output shape, [M1*M2, d2, ..., dN]", "T_INT") .TypeConstraint( "T_INT", {"tensor(int32)", "tensor(int64)"}, @@ -4758,6 +4756,26 @@ Return true if all elements are true and false otherwise. {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types"); +ONNX_CONTRIB_OPERATOR_SCHEMA(FlattenAndUnpad) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc( + "FlattenAndUnpad operator flattens the first two dims of input tensor, and unpad according to given indices." 
+ " This is used by padding elimination graph transformers.") + .Input(0, "input", "input data of rank N+1, shape is [M1, M2, d2, ..., dN]", "T") + .Input(1, "indices", "1D Tensor of int32/int64 indices, shape is [d1], each element's value ranges in [0, M1*M2).", + "T_INT") + .Output(0, "output", "output data of rank N, [d1, d2, ..., dN]", "T") + .Output(1, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT") + .TypeConstraint( + "T_INT", + {"tensor(int32)", "tensor(int64)"}, + "Constrain indices and shape to integer tensors.") + .TypeConstraint( + "T", + {"tensor(int32)", "tensor(int64)", "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, + "Constrain input and output types to numeric tensors."); + ONNX_CONTRIB_OPERATOR_SCHEMA(GRUTraining) .SetDomain(kMSDomain) .SinceVersion(1) diff --git a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc index 74247c059cf84..6fd8cfdc13249 100644 --- a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc +++ b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc @@ -129,91 +129,43 @@ NodeArg* InsertExpandForNodeInput(Graph& graph, return new_expand_node->MutableOutputDefs()[0]; } -// Insert Reshape + ShrunkenGather to flatten the in_index-th input of node. +// Insert FlattenAndUnpad to flatten the in_index-th input of node. // The gather_index_arg is the indices of the elements that are not padding. 
NodeArg* InsertFlattenPatternForInput(Graph& graph, Node& node, uint32_t in_index, NodeArg* gather_index_arg, const logging::Logger& logger) { - InlinedVector reshape_input_args; - reshape_input_args.reserve(2); - reshape_input_args.push_back(node.MutableInputDefs()[in_index]); - std::vector new_shape; - new_shape.push_back(-1); // only support flatten 0 and 1 dims - auto input_shape = node.InputDefs()[in_index]->Shape(); - ORT_ENFORCE(input_shape->dim_size() >= 2); - ONNX_NAMESPACE::TensorShapeProto flattened_shape; - if (input_shape->dim(0).has_dim_value() && input_shape->dim(1).has_dim_value()) { - flattened_shape.add_dim()->set_dim_value(input_shape->dim(0).dim_value() * input_shape->dim(1).dim_value()); - } else { - std::string token_dim_name = MakeString("total_token_count_", utils::GetRandomSeed()); - flattened_shape.add_dim()->set_dim_param(token_dim_name); - } - for (int k = 2; k < input_shape->dim_size(); k++) { - ORT_ENFORCE(input_shape->dim(k).has_dim_value()); - new_shape.push_back(input_shape->dim(k).dim_value()); - flattened_shape.add_dim()->set_dim_value(input_shape->dim(k).dim_value()); - } - ONNX_NAMESPACE::TensorProto new_shape_const_tensor; - new_shape_const_tensor.set_name(graph.GenerateNodeArgName("new_shape")); - new_shape_const_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); - new_shape_const_tensor.add_dims(new_shape.size()); - new_shape_const_tensor.set_raw_data(new_shape.data(), new_shape.size() * sizeof(int64_t)); - NodeArg* new_shape_arg = &graph_utils::AddInitializer(graph, new_shape_const_tensor); - reshape_input_args.push_back(new_shape_arg); - - InlinedVector reshape_output_args; - reshape_output_args.push_back( - &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("inputs_reshape_result"), - node.MutableInputDefs()[in_index]->TypeAsProto())); - - Node* new_reshape_node = InsertIntermediateNodeOnDestInput( - graph, node, - in_index, - 0, - 0, - graph.GenerateNodeName("Reshape"), - "Reshape", - "Reshape node to 
filter invalid tokens.", - reshape_input_args, - reshape_output_args, - {}, - "", - logger); + InlinedVector unpad_input_args; + unpad_input_args.reserve(2); + unpad_input_args.push_back(node.MutableInputDefs()[in_index]); + unpad_input_args.push_back(gather_index_arg); - new_reshape_node->SetExecutionProviderType(node.GetExecutionProviderType()); - auto reshape_out_arg = new_reshape_node->MutableOutputDefs()[0]; - - reshape_out_arg->SetShape(flattened_shape); - - InlinedVector gather_input_args; - gather_input_args.reserve(2); - gather_input_args.push_back(reshape_output_args[0]); - gather_input_args.push_back(gather_index_arg); - - InlinedVector gather_output_args; - gather_output_args.push_back( + InlinedVector unpad_output_args; + unpad_output_args.push_back( &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padding_filter_result"), - reshape_out_arg->TypeAsProto())); + nullptr)); + unpad_output_args.push_back( + &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("d1_d2_shape"), + nullptr)); - Node* new_gather_node = InsertIntermediateNodeOnDestInput( + Node* unpad_node = InsertIntermediateNodeOnDestInput( graph, node, in_index, 0, 0, graph.GenerateNodeName("PaddingFilter"), - "ShrunkenGather", - "ShrunkenGather node to filter invalid tokens.", - gather_input_args, - gather_output_args, + "FlattenAndUnpad", + "FlattenAndUnpad node to filter invalid tokens.", + unpad_input_args, + unpad_output_args, {}, kMSDomain, logger); - new_gather_node->SetExecutionProviderType(node.GetExecutionProviderType()); - auto gather_out_arg = new_gather_node->MutableOutputDefs()[0]; - return gather_out_arg; + unpad_node->SetExecutionProviderType(node.GetExecutionProviderType()); + auto unpad_out_arg = unpad_node->MutableOutputDefs()[0]; + return unpad_out_arg; } // Insert PadAndUnflatten to unflatten the shape of the in_index-th input of node. 
@@ -236,10 +188,6 @@ NodeArg* InsertNodesForOutput(Graph& graph, pad_node_output_args.push_back( &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_result"), nullptr)); - pad_node_output_args.push_back( - &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_d1xd2_shape"), - nullptr)); - Node* new_gathergrad_node = InsertIntermediateNodeOnDestInput( graph, node, in_index, diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 597801f4030c1..5da69f7a99e6c 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -3011,7 +3011,6 @@ TEST(GradientCheckerTest, PadAndUnflattenGrad) { std::vector> x_datas = {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 5, 0, 1}, {5, 2}}; TensorInfo padded_out_info({5, 2, 3}, true); - TensorInfo out_shape_info({2}, false, nullptr, DataTypeImpl::GetTensorType()); std::vector> execution_providers; #ifdef USE_CUDA @@ -3021,7 +3020,7 @@ TEST(GradientCheckerTest, PadAndUnflattenGrad) { #endif ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info, indices_info, shape_info}, - {padded_out_info, out_shape_info}, &max_error, + {padded_out_info}, &max_error, x_datas, {}, true, false, &execution_providers)); EXPECT_IS_TINY(max_error); } diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 643d47b0d043e..4de25d152e2df 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -5901,9 +5901,9 @@ def generate_inputs(batch_size, max_seq_length, vocab_size): assert len([node.op_type for node in training_model.graph.node if node.op_type == "Squeeze"]) == 1 assert len([node.op_type for node in training_model.graph.node if node.op_type == "PadAndUnflatten"]) 
== 1 if case >= 2: - assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 2 + assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 3 else: - assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 1 + assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 2 gathergrad_node = next(node for node in training_model.graph.node if node.op_type == "PadAndUnflatten") def find_input_node_type(model, arg): diff --git a/orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc b/orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc new file mode 100644 index 0000000000000..3c4fa40342269 --- /dev/null +++ b/orttraining/orttraining/test/training_ops/cuda/flatten_and_unpad_test.cc @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "test/common/tensor_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +#if defined(USE_CUDA) || defined(USE_ROCM) + +TEST(FlattenAndUnpadTest, FloatType1D) { + std::vector input = {1.0f, 1.0f, 3.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f, + 0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f}; + std::vector indices = {1, 3, 5, 7, 9, 11}; + + std::vector output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f}; + std::vector unflatten_dims = {5, 3}; + + OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain); + test.AddInput("input", {5, 3}, input); + test.AddInput("indices", {6}, indices); + test.AddOutput("output", {6}, output); + test.AddOutput("unflatten_dims", {2}, unflatten_dims); + test.Run(); +} + +TEST(FlattenAndUnpadTest, FloatType2D) { + std::vector input = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f, + 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f}; + std::vector indices = {1, 3, 4}; + + std::vector output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f, 7.f, 8.f, 9.f}; + std::vector unflatten_dims = {2, 3}; + + OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain); + test.AddInput("input", {2, 3, 3}, input); + test.AddInput("indices", {3}, indices); + test.AddOutput("output", {3, 3}, output); + test.AddOutput("unflatten_dims", {2}, unflatten_dims); + test.Run(); +} + +TEST(FlattenAndUnpadTest, MLFloat16Type1D) { + std::vector input = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f, + 0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f}; + std::vector indices = {1, 3, 5, 7, 9, 11}; + + std::vector output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f}; + std::vector unflatten_dims = {5, 3}; + + OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain); + test.AddInput("input", {5, 3}, input); + test.AddInput("indices", {6}, indices); + test.AddOutput("output", {6}, output); + test.AddOutput("unflatten_dims", {2}, unflatten_dims); + test.Run(); +} + +TEST(FlattenAndUnpadTest, MLFloat16Type2D) { + std::vector input = 
{0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f, + 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f}; + std::vector indices = {1, 3, 4}; + + std::vector output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f, 7.f, 8.f, 9.f}; + std::vector unflatten_dims = {2, 3}; + + OpTester test("FlattenAndUnpad", 1, onnxruntime::kMSDomain); + test.AddInput("input", {2, 3, 3}, input); + test.AddInput("indices", {3}, indices); + test.AddOutput("output", {3, 3}, output); + test.AddOutput("unflatten_dims", {2}, unflatten_dims); + test.Run(); +} + +#endif + +} // namespace test +} // namespace onnxruntime diff --git a/orttraining/orttraining/test/training_ops/cuda/pad_and_unflatten_test.cc b/orttraining/orttraining/test/training_ops/cuda/pad_and_unflatten_test.cc index a800f17e59ae0..9a86955e09379 100644 --- a/orttraining/orttraining/test/training_ops/cuda/pad_and_unflatten_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/pad_and_unflatten_test.cc @@ -17,14 +17,11 @@ TEST(PadAndUnflattenTest, FloatType1D) { std::vector output = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f, 0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f}; - std::vector full_flatten_dims = {15}; - OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain); test.AddInput("input", {6}, input); test.AddInput("indices", {6}, indices); test.AddInput("unflatten_dims", {2}, unflatten_dims); test.AddOutput("output", {5, 3}, output); - test.AddOutput("full_flatten_dims", {1}, full_flatten_dims); test.Run(); } @@ -36,14 +33,11 @@ TEST(PadAndUnflattenTest, FloatType2D) { std::vector output = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f}; - std::vector full_flatten_dims = {6, 3}; - OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain); test.AddInput("input", {3, 3}, input); test.AddInput("indices", {3}, indices); test.AddInput("unflatten_dims", {2}, unflatten_dims); test.AddOutput("output", {2, 3, 3}, output); - test.AddOutput("full_flatten_dims", {2}, 
full_flatten_dims); test.Run(); } @@ -55,8 +49,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type1D) { std::vector output = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f, 0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f}; - std::vector full_flatten_dims = {15}; - std::vector input_half; input_half.resize(input.size()); ConvertFloatToMLFloat16(input.data(), input_half.data(), int(input.size())); @@ -69,7 +61,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type1D) { test.AddInput("indices", {6}, indices); test.AddInput("unflatten_dims", {2}, unflatten_dims); test.AddOutput("output", {5, 3}, output_half); - test.AddOutput("full_flatten_dims", {1}, full_flatten_dims); test.Run(); } @@ -81,8 +72,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type2D) { std::vector output = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f}; - std::vector full_flatten_dims = {6, 3}; - std::vector input_half; input_half.resize(input.size()); ConvertFloatToMLFloat16(input.data(), input_half.data(), int(input.size())); @@ -95,7 +84,6 @@ TEST(PadAndUnflattenTest, MLFloat16Type2D) { test.AddInput("indices", {3}, indices); test.AddInput("unflatten_dims", {2}, unflatten_dims); test.AddOutput("output", {2, 3, 3}, output_half); - test.AddOutput("full_flatten_dims", {2}, full_flatten_dims); test.Run(); } diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index 8e61dbee506f2..19a36cbb536a8 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -206,6 +206,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BatchScale); class 
ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PadAndUnflatten); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, FlattenAndUnpad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ScaledSum); // the kernels within the following ifdef are not included in a build with @@ -459,6 +460,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad)>, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // the kernels within the following ifdef are not included in a build with // --enable_training_ops but without --enable_training diff --git a/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.cc b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.cc new file mode 100644 index 0000000000000..5e8f2db9c8f91 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.cc @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/training_ops/cuda/tensor/flatten_and_unpad.h" +#include "orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace cuda { + +ONNX_OPERATOR_KERNEL_EX( + FlattenAndUnpad, + kMSDomain, + 1, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", BuildKernelDefConstraints()) + .TypeConstraint("T_INT", DataTypeImpl::GetTensorType()) + .OutputMemoryType(OrtMemTypeCPUOutput, 1), + FlattenAndUnpad); + +// Put implementation in the anonymous namespace to avoid name collision in the global namespace. 
+namespace { + +template +struct FlattenAndUnpadFunctor { + void operator()(cudaStream_t stream, + const int64_t output_element_count, + const fast_divmod output_element_stride_fdm, + const int64_t index_value_upper_bound, + const Tensor& input_tensor, + const Tensor& indices_tensor, + Tensor& output_tensor) const { + typedef typename ToCudaType::MappedType CudaT; + const CudaT* input_data = reinterpret_cast(input_tensor.Data()); + + FlattenAndUnpadImpl(stream, output_element_count, output_element_stride_fdm, index_value_upper_bound, + input_data, indices_tensor.Data(), + reinterpret_cast(output_tensor.MutableData())); + } +}; + +} // namespace + +Status FlattenAndUnpad::ComputeInternal(OpKernelContext* context) const { + const Tensor* input_tensor = context->Input(0); + const Tensor* indices_tensor = context->Input(1); + ORT_ENFORCE(indices_tensor->Shape().NumDimensions() == 1, + "indices tensor must be 1-D, but got rank: ", indices_tensor->Shape().NumDimensions()); + ORT_ENFORCE(input_tensor->Shape().NumDimensions() >= 2, "input tensor must have at least 2 dimensions."); + std::vector output_shape_vec; + output_shape_vec.push_back(indices_tensor->Shape()[0]); + const auto& input_shape = input_tensor->Shape(); + int64_t element_stride = 1; // product of the input dims after the first two + for (size_t i = 2; i < input_shape.NumDimensions(); ++i) { + output_shape_vec.push_back(input_shape[i]); + element_stride *= input_shape[i]; + } + + fast_divmod output_element_stride_fdm(static_cast(element_stride)); + auto output_shape = TensorShape(output_shape_vec); + Tensor* output_tensor = context->Output(0, output_shape); + + std::vector unflatten_dims_vec; + unflatten_dims_vec.push_back(input_shape[0]); + unflatten_dims_vec.push_back(input_shape[1]); + const int64_t index_value_upper_bound = input_shape[0] * input_shape[1]; // exclusive bound for indices + + utils::MLTypeCallDispatcher t_disp(input_tensor->GetElementType()); + t_disp.Invoke(Stream(context), + output_shape.Size(), + output_element_stride_fdm, + index_value_upper_bound, + *input_tensor, + *indices_tensor, + *output_tensor); + + size_t rank = unflatten_dims_vec.size(); + Tensor* 
unflatten_dims_tensor = context->Output(1, {static_cast(rank)}); + TensorShape(unflatten_dims_vec).CopyDims(unflatten_dims_tensor->MutableData(), rank); + + return Status::OK(); +} + +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.h b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.h new file mode 100644 index 0000000000000..f9c6819a393b8 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/common.h" + +namespace onnxruntime { +namespace cuda { + +class FlattenAndUnpad final : public CudaKernel { + public: + explicit FlattenAndUnpad(const OpKernelInfo& info) : CudaKernel(info) { + } + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.cu b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.cu new file mode 100644 index 0000000000000..69cd0c7cd5445 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.cu @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" + +namespace onnxruntime { +namespace cuda { + +constexpr int kBlockSize = 256; +constexpr int kNumUnroll = 4; + +template +__global__ void ExtractIputWithIndexKernel(const CUDA_LONG N, + const fast_divmod output_element_stride_fdm, + const int64_t index_value_upper_bound, + const T* input_data, + const int64_t* indices_data, + T* output_data) { + CUDA_LONG idx = blockDim.x * blockIdx.x + threadIdx.x; + CUDA_LONG id = idx * kNumUnroll; + + T input[kNumUnroll]; + if (id < N) { +#pragma unroll + for (int i = 0; i < kNumUnroll; ++i) { + CUDA_LONG li = id + i; + if (li < N) { + int row_index, col_index; + output_element_stride_fdm.divmod(li, row_index, col_index); + assert(indices_data[row_index] < index_value_upper_bound); + input[i] = input_data[indices_data[row_index] * output_element_stride_fdm.d_ + col_index]; + } + } + } + +#pragma unroll + for (int i = 0; i < kNumUnroll; ++i) { + CUDA_LONG li = id + i; + if (li < N) { + output_data[li] = input[i]; + } + } +} + +template +void FlattenAndUnpadImpl(cudaStream_t stream, + const int64_t total_element_count, + const fast_divmod output_element_stride_fdm, + const int64_t index_value_upper_bound, + const T* input_data, + const int64_t* indices_data, + T* output_data) { + const int blocksPerGrid = static_cast(CeilDiv(total_element_count, kBlockSize * kNumUnroll)); + ExtractIputWithIndexKernel<<>>( + static_cast(total_element_count), + output_element_stride_fdm, + index_value_upper_bound, + input_data, + indices_data, + output_data); +} + +#define SPECIALIZED_RESTORE_FROM_MASK_IMPL(T) \ + template void FlattenAndUnpadImpl(cudaStream_t stream, \ + const int64_t total_element_count, \ + const fast_divmod output_element_stride_fdm, \ + const int64_t index_value_upper_bound, \ + const T* input_data, \ + const int64_t* indices_data, \ + T* output_data); + +SPECIALIZED_RESTORE_FROM_MASK_IMPL(float) 
+SPECIALIZED_RESTORE_FROM_MASK_IMPL(double) +SPECIALIZED_RESTORE_FROM_MASK_IMPL(half) +SPECIALIZED_RESTORE_FROM_MASK_IMPL(BFloat16) +SPECIALIZED_RESTORE_FROM_MASK_IMPL(int32_t) +SPECIALIZED_RESTORE_FROM_MASK_IMPL(int64_t) + +#undef SPECIALIZED_RESTORE_FROM_MASK_IMPL + +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h new file mode 100644 index 0000000000000..75f8c243d3425 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad_impl.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#ifdef USE_ROCM +#include "core/providers/rocm/shared_inc/rocm_utils.h" +#else +#include "core/providers/cuda/shared_inc/cuda_utils.h" +#endif + +namespace onnxruntime { +namespace cuda { + +template +void FlattenAndUnpadImpl(cudaStream_t stream, + const int64_t total_element_count, + const fast_divmod output_element_stride_fdm, + const int64_t index_value_upper_bound, + const T* input_data, + const int64_t* indices_data, + T* output_data); + +} // namespace cuda +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc index caf89ef840e0c..7bd759e8976c1 100644 --- a/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc +++ b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc @@ -17,8 +17,7 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("T", BuildKernelDefConstraints()) .TypeConstraint("T_INT", DataTypeImpl::GetTensorType()) .TypeConstraint("T_INDEX", DataTypeImpl::GetTensorType()) - .InputMemoryType(OrtMemTypeCPUInput, 2) - .OutputMemoryType(OrtMemTypeCPUOutput, 1), + .InputMemoryType(OrtMemTypeCPUInput, 2), PadAndUnflatten); // Put implementation in the 
anonymous namespace to avoid name collision in the global namespace. @@ -63,14 +62,11 @@ Status PadAndUnflatten::ComputeInternal(OpKernelContext* context) const { output_shape_vec.push_back(dims_ptr[0]); output_shape_vec.push_back(dims_ptr[1]); - std::vector full_size_flatten_shape_vec; const int64_t flatten_dim_factor = dims_ptr[0] * dims_ptr[1]; - full_size_flatten_shape_vec.push_back(flatten_dim_factor); int64_t element_stride = 1; for (size_t i = 1; i < input_shape.NumDimensions(); ++i) { output_shape_vec.push_back(input_shape[i]); - full_size_flatten_shape_vec.push_back(input_shape[i]); element_stride *= input_shape[i]; } @@ -87,11 +83,6 @@ Status PadAndUnflatten::ComputeInternal(OpKernelContext* context) const { *indices_tensor, *output_tensor); - // Set input shape output tensor. - size_t rank = full_size_flatten_shape_vec.size(); - Tensor* input_shape_tensor = context->Output(1, {static_cast(rank)}); - TensorShape(full_size_flatten_shape_vec).CopyDims(input_shape_tensor->MutableData(), rank); - return Status::OK(); } diff --git a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc index 2321aa23dd6eb..c959bd6c2eb6e 100644 --- a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc +++ b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc @@ -187,6 +187,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_BFloat16, ReduceAllL2); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PadAndUnflatten); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, FlattenAndUnpad); #if defined(ORT_USE_NCCL) || defined(USE_MPI) // P2P communication operators. 
@@ -387,6 +388,7 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // P2P communication operators. #if defined(ORT_USE_NCCL) || defined(USE_MPI)