diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.cc index a3ac4312053aa..dd38ee9b07ee6 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.cc @@ -462,6 +462,27 @@ bool LayerNormalizationGatherActor::PreCheck(const Graph& /* graph */, return true; } +bool LayerNormalizationGatherActor::PostProcess(Graph& /*graph*/, Node& current_node, + const SliceInfo& info_without_node, + const logging::Logger& /*logger*/, + const std::unordered_map& /*propagate_input_indices*/, + const std::unordered_map>& + /*all_input_cmp_rets*/, + const std::unordered_map& /*new_gather_infos*/) { + // Update LayerNormalization's axis attribute if it is scalar slice. + if (info_without_node.is_scalar_slice) { + auto axis = static_cast(current_node.GetAttributes().at("axis").i()); + auto original_ln_input_rank = info_without_node.input_rank; + axis = axis < 0 ? 
axis + original_ln_input_rank : axis; + auto new_axis = axis - 1; + + auto& attributes = current_node.GetMutableAttributes(); + attributes["axis"] = ONNX_NAMESPACE::MakeAttribute("axis", static_cast(new_axis)); + } + + return true; +} + bool SoftmaxGatherActor::PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info, const logging::Logger& logger, std::unordered_map& propagate_input_indices, @@ -479,6 +500,28 @@ bool SoftmaxGatherActor::PreCheck(const Graph& graph, const Node& current_node, propagate_input_indices, all_input_cmp_rets, shape_update_func); } +bool SoftmaxGatherActor::PostProcess(Graph& graph, Node& current_node, const SliceInfo& info_without_node, + const logging::Logger& logger, + const std::unordered_map& propagate_input_indices, + const std::unordered_map>& all_input_cmp_rets, + const std::unordered_map& new_gather_infos) { + SimplePointwiseGatherActor::PostProcess(graph, current_node, info_without_node, logger, + propagate_input_indices, all_input_cmp_rets, new_gather_infos); + + // Update Softmax's axis attribute if it is scalar slice. + if (info_without_node.is_scalar_slice) { + auto axis = static_cast(current_node.GetAttributes().at("axis").i()); + auto original_ln_input_rank = info_without_node.input_rank; + axis = axis < 0 ? 
axis + original_ln_input_rank : axis; + auto new_axis = axis - 1; + + auto& attributes = current_node.GetMutableAttributes(); + attributes["axis"] = ONNX_NAMESPACE::MakeAttribute("axis", static_cast(new_axis)); + } + + return true; +} + bool ReshapeGatherActor::PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info, const logging::Logger& logger, std::unordered_map& propagate_input_indices, @@ -566,6 +609,11 @@ bool ReshapeGatherActor::PreCheck(const Graph& graph, const Node& current_node, return true; } + LOG_DEBUG_INFO(logger, "Skip handle the Reshape, new_shape_const_values[info.non_negative_axis]:" + + std::to_string(new_shape_const_values[info.non_negative_axis]) + + ", info.output_dim_on_axis.has_dim_value(): " + + std::to_string(info.output_dim_on_axis.has_dim_value()) + "."); + return false; } @@ -604,11 +652,12 @@ bool ReshapeGatherActor::PostProcess( return true; } - // If it selected shape is a dim value, we can update the shape tensor directory. + // If the selected shape is a dim value, we can update the shape tensor directly. 
if (info_without_node.output_dim_on_axis.has_dim_value()) { new_shape_const_values[slice_axis] = info_without_node.output_dim_on_axis.dim_value(); auto new_shape_arg = - CreateInitializerFromVector(graph, {static_cast(new_shape_const_values.size())}, new_shape_const_values, + CreateInitializerFromVector(graph, {static_cast(new_shape_const_values.size())}, + new_shape_const_values, graph.GenerateNodeArgName(current_node.MutableInputDefs()[1]->Name())); graph_utils::ReplaceNodeInput(current_node, 1, *new_shape_arg); return true; diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.h b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.h index f6715e4bb1f32..0c21be1397636 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.h +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.h @@ -189,7 +189,7 @@ class LayerNormalizationGatherActor : public UpStreamGatherOperatorActorBase { const logging::Logger& /* logger */, const std::unordered_map& /* propagate_input_indices */, const std::unordered_map>& /* all_input_cmp_rets */, - const std::unordered_map& /* new_gather_infos */) override { return true; } + const std::unordered_map& /* new_gather_infos */) override; }; class SoftmaxGatherActor : public SimplePointwiseGatherActor { @@ -202,6 +202,12 @@ class SoftmaxGatherActor : public SimplePointwiseGatherActor { std::unordered_map& propagate_input_indices, std::unordered_map>& all_input_cmp_rets, std::function& shape_update_func) override; + + bool PostProcess(Graph& /* graph */, Node& /* current_node */, const SliceInfo& /* info_without_node */, + const logging::Logger& /* logger */, + const std::unordered_map& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_gather_infos */) override; }; class ReshapeGatherActor : public UpStreamGatherOperatorActorBase { diff --git 
a/onnxruntime/test/optimizer/compute_optimizer_test.cc b/onnxruntime/test/optimizer/compute_optimizer_test.cc index 01016774288e4..a03d0da2538d4 100644 --- a/onnxruntime/test/optimizer/compute_optimizer_test.cc +++ b/onnxruntime/test/optimizer/compute_optimizer_test.cc @@ -638,7 +638,8 @@ TEST(ComputeOptimizerTests, GatherMatMul_ScalarSlicingOnSecondLastDim) { std::map op_to_count = CountOpsInGraph(graph); onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), + TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); GraphViewer graph_viewer(graph); @@ -737,7 +738,8 @@ TEST(ComputeOptimizerTests, GatherMatMul_SlicingOnSecondLastDim) { std::map op_to_count = CountOpsInGraph(graph); onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), + TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); GraphViewer graph_viewer(graph); @@ -826,6 +828,345 @@ TEST(ComputeOptimizerTests, GatherMatMul_SlicingOnSecondLastDim) { } } +/* +Test graph includes multiple equivalent subgraphs as below. + graph input [2, 32, 256] (float) + | + LayerNormalization[axis=-1 (as example)] + | + [2, 32, 256] + | + | 0 (scalar) + | / + Gather[axis=1] + | + Identity + | + graph output [2, 256] (float) + +Add an Identity node because, currently, we don't allow Gather to generate a graph output. 
+*/ +TEST(ComputeOptimizerTests, GatherLayerNormalization) { + std::vector> test_config_pairs{ + // { + // is_scalar_slice, + // ln_axis_before_propagation, + // expected_ln_axis_after_propagation, + // expected to propagate + // } + {true, 0, 0, false}, + {true, 1, 1, false}, + {true, 2, 1, true}, + {true, -3, -3, false}, + {true, -2, -2, false}, + {true, -1, 1, true}, + {false, 0, 0, false}, + {false, 1, 1, false}, + {false, 2, 2, true}, + {false, -3, -3, false}, + {false, -2, -2, false}, + {false, -1, -1, true}, + }; + + constexpr static int64_t gather_axis = 1; + constexpr static int64_t slice_data_value = 0; + + for (auto p : test_config_pairs) { + bool is_scalar_slice = std::get<0>(p); + int64_t ln_axis_before = std::get<1>(p); + int64_t ln_axis_after = std::get<2>(p); + bool expected_to_propagate = std::get<3>(p); + + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + + InlinedVector indices; + auto pre_graph_checker = [&indices](Graph& graph) -> Status { + auto op_count_pre = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_pre.size() == 3U); + TEST_RETURN_IF_NOT(op_count_pre["LayerNormalization"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Gather"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1); + + for (Node& node : graph.Nodes()) { + if (node.OpType() == "Gather") { + TEST_RETURN_IF_NOT(indices.empty()); + constexpr bool require_constant = true; + NodeArg* initializer_node_arg = graph.GetNodeArg(node.InputDefs()[1]->Name()); + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, + indices, require_constant)); + } + } + return Status::OK(); + }; + + auto post_graph_checker = [is_scalar_slice, ln_axis_after, + &indices, expected_to_propagate](Graph& graph) { + auto op_count_post = CountOpsInGraph(graph); + + TEST_RETURN_IF_NOT(op_count_post.size() == 3U); + TEST_RETURN_IF_NOT(op_count_post["LayerNormalization"] == 1); + TEST_RETURN_IF_NOT(op_count_post["Gather"] == 1); + 
TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1); + + for (Node& node : graph.Nodes()) { + if (node.OpType() == "LayerNormalization") { + const auto& input_defs = node.InputDefs(); + + auto producer_node = graph.GetProducerNode(input_defs[0]->Name()); + if (expected_to_propagate) { + TEST_RETURN_IF_NOT(producer_node != nullptr); + TEST_RETURN_IF_NOT(producer_node->OpType() == "Gather"); + + InlinedVector values; + constexpr bool require_constant = true; + NodeArg* initializer_node_arg = graph.GetNodeArg(producer_node->InputDefs()[1]->Name()); + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, + values, require_constant)); + for (size_t i = 0; i < values.size(); i++) { + TEST_RETURN_IF_NOT(values[i] == indices[i]); + } + + const ONNX_NAMESPACE::TensorShapeProto* slice_out_shape = producer_node->OutputDefs()[0]->Shape(); + TEST_RETURN_IF_NOT(slice_out_shape != nullptr); + + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + TEST_RETURN_IF_NOT(axis_value == ln_axis_after); + + if (is_scalar_slice) { + TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 2); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) && + slice_out_shape->dim(0).dim_value() == 2); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(1)) && + slice_out_shape->dim(1).dim_value() == 256); + } else { + TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 3); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) && + slice_out_shape->dim(0).dim_value() == 2); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(1)) && + slice_out_shape->dim(1).dim_value() == 1); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(2)) && + slice_out_shape->dim(2).dim_value() == 256); + } + + } else { + TEST_RETURN_IF_NOT(producer_node == nullptr); + } + } + } + + return Status::OK(); + }; + + auto 
build_test_case = [is_scalar_slice, ln_axis_before](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{2, 32, 256}}); + auto* input2_arg = builder.MakeInput({{256}}); + auto* input3_arg = builder.MakeInput({{256}}); + auto* ln_out = builder.MakeIntermediate(); + builder.AddNode("LayerNormalization", {input1_arg, input2_arg, input3_arg}, {ln_out}) + .AddAttribute("axis", ln_axis_before); + + std::vector slice_inputs; + NodeArg* indices_initializer = nullptr; + + if (is_scalar_slice) { + indices_initializer = builder.MakeScalarInitializer(slice_data_value); + } else { + indices_initializer = builder.MakeInitializer({1}, {slice_data_value}); + } + + slice_inputs = {ln_out, indices_initializer}; + + auto* gather_out = builder.MakeIntermediate(); + builder.AddNode("Gather", slice_inputs, + {gather_out}) + .AddAttribute("axis", gather_axis); + + auto* identity_out = builder.MakeOutput(); + builder.AddNode("Identity", {gather_out}, {identity_out}); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger, std::move(transformer), + TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); + } +} + +/* +Test graph includes multiple equivalent subgraphs as below. + graph input [2, 4, 32, 256] (float) + | + Softmax[axis=3 (as example)] + | + [2, 4, 32, 256] + | + | 0 (scalar) + | / + Gather[axis=1] + | + Identity + | + graph output [2, 32, 256] (float) + +Add an Identity node because, currently, we don't allow Gather to generate a graph output. 
+*/ +TEST(ComputeOptimizerTests, GatherSoftmax) { + std::vector> test_config_pairs{ + // {is_scalar_slice, softmax_axis_before_propagation, + // expected_softmax_axis_after_propagation, expected to propagate} + {true, 0, 0, false}, + {true, 1, 1, false}, + {true, 2, 1, true}, + {true, 3, 2, true}, + {true, -4, -4, false}, + {true, -3, -3, false}, + {true, -2, 1, true}, + {true, -1, 2, true}, + {false, 0, 0, false}, + {false, 1, 1, false}, + {false, 2, 2, true}, + {false, 3, 3, true}, + {false, -4, -4, false}, + {false, -3, -3, false}, + {false, -2, -2, true}, + {false, -1, -1, true}, + }; + + constexpr static int64_t gather_axis = 1; + constexpr static int64_t slice_data_value = 0; + + for (auto p : test_config_pairs) { + bool is_scalar_slice = std::get<0>(p); + int64_t softmax_axis_before = std::get<1>(p); + int64_t softmax_axis_after = std::get<2>(p); + bool expected_to_propagate = std::get<3>(p); + + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + + InlinedVector indices; + auto pre_graph_checker = [&indices](Graph& graph) -> Status { + auto op_count_pre = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_pre.size() == 3U); + TEST_RETURN_IF_NOT(op_count_pre["Softmax"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Gather"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1); + + for (Node& node : graph.Nodes()) { + if (node.OpType() == "Gather") { + TEST_RETURN_IF_NOT(indices.empty()); + constexpr bool require_constant = true; + NodeArg* initializer_node_arg = graph.GetNodeArg(node.InputDefs()[1]->Name()); + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, + indices, require_constant)); + } + } + return Status::OK(); + }; + + auto post_graph_checker = [is_scalar_slice, softmax_axis_after, + &indices, expected_to_propagate](Graph& graph) { + auto op_count_post = CountOpsInGraph(graph); + + TEST_RETURN_IF_NOT(op_count_post.size() == 3U); + TEST_RETURN_IF_NOT(op_count_post["Softmax"] == 
1); + TEST_RETURN_IF_NOT(op_count_post["Gather"] == 1); + TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1); + + for (Node& node : graph.Nodes()) { + if (node.OpType() == "Softmax") { + const auto& input_defs = node.InputDefs(); + + auto producer_node = graph.GetProducerNode(input_defs[0]->Name()); + if (expected_to_propagate) { + TEST_RETURN_IF_NOT(producer_node != nullptr); + TEST_RETURN_IF_NOT(producer_node->OpType() == "Gather"); + + InlinedVector values; + constexpr bool require_constant = true; + NodeArg* initializer_node_arg = graph.GetNodeArg(producer_node->InputDefs()[1]->Name()); + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, values, + require_constant)); + for (size_t i = 0; i < values.size(); i++) { + TEST_RETURN_IF_NOT(values[i] == indices[i]); + } + + const ONNX_NAMESPACE::TensorShapeProto* slice_out_shape = producer_node->OutputDefs()[0]->Shape(); + TEST_RETURN_IF_NOT(slice_out_shape != nullptr); + + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + TEST_RETURN_IF_NOT(axis_value == softmax_axis_after); + + if (is_scalar_slice) { + TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 3); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) && + slice_out_shape->dim(0).dim_value() == 2); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(1)) && + slice_out_shape->dim(1).dim_value() == 32); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(2)) && + slice_out_shape->dim(2).dim_value() == 256); + } else { + TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 4); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) && + slice_out_shape->dim(0).dim_value() == 2); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(1)) && + slice_out_shape->dim(1).dim_value() == 1); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(2)) && + 
slice_out_shape->dim(2).dim_value() == 32); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(3)) && + slice_out_shape->dim(3).dim_value() == 256); + } + + } else { + TEST_RETURN_IF_NOT(producer_node == nullptr); + } + } + } + + return Status::OK(); + }; + + auto build_test_case = [is_scalar_slice, softmax_axis_before](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{2, 4, 32, 256}}); + auto* softmax_out = builder.MakeIntermediate(); + builder.AddNode("Softmax", {input1_arg}, {softmax_out}) + .AddAttribute("axis", softmax_axis_before); + + std::vector slice_inputs; + + NodeArg* indices_initializer = nullptr; + + if (is_scalar_slice) { + indices_initializer = builder.MakeScalarInitializer(slice_data_value); + } else { + indices_initializer = builder.MakeInitializer({1}, {slice_data_value}); + } + + slice_inputs = {softmax_out, indices_initializer}; + + auto* gather_out = builder.MakeIntermediate(); + builder.AddNode("Gather", slice_inputs, + {gather_out}) + .AddAttribute("axis", gather_axis); + + auto* identity_out = builder.MakeOutput(); + builder.AddNode("Identity", {gather_out}, {identity_out}); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger, std::move(transformer), + TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); + } +} + TEST(ComputeOptimizerTests, GatherReshape_ScalarSlicingOnBatchDim) { const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); auto model_uri = MODEL_FOLDER "computation_reduction/gather/gather_reshape_scalar_batch_dim.onnx"; @@ -835,7 +1176,8 @@ TEST(ComputeOptimizerTests, GatherReshape_ScalarSlicingOnBatchDim) { std::map op_to_count = CountOpsInGraph(graph); onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + 
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), + TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); GraphViewer graph_viewer(graph); @@ -928,7 +1270,8 @@ TEST(ComputeOptimizerTests, GatherReshape_SlicingOnBatchDim) { std::map op_to_count = CountOpsInGraph(graph); onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), + TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); GraphViewer graph_viewer(graph);