From ca596113025e5568cb79853125d1d59c7c695605 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Tue, 29 Oct 2024 19:50:58 +0800 Subject: [PATCH] reshape related fusion --- .../core/optimizer/matmul_add_fusion.cc | 41 +++- onnxruntime/core/optimizer/reshape_fusion.cc | 71 ++++++ onnxruntime/core/optimizer/reshape_fusion.h | 5 + .../qnn/builder/opbuilder/base_op_builder.cc | 33 --- .../qnn/builder/opbuilder/base_op_builder.h | 82 ------- .../qnn/builder/opbuilder/conv_op_builder.cc | 16 +- .../qnn/builder/opbuilder/gemm_op_builder.cc | 6 +- .../builder/qnn_node_group/qnn_node_group.cc | 2 + .../qnn_node_group/reshape_gemm_fusion.cc | 208 ++++++++++++++++++ .../qnn_node_group/reshape_gemm_fusion.h | 48 ++++ .../core/providers/qnn/builder/qnn_utils.cc | 108 ++++++++- .../core/providers/qnn/builder/qnn_utils.h | 19 ++ .../test/optimizer/graph_transform_test.cc | 32 +++ .../test/providers/qnn/gemm_op_test.cc | 41 ++++ 14 files changed, 576 insertions(+), 136 deletions(-) create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h diff --git a/onnxruntime/core/optimizer/matmul_add_fusion.cc b/onnxruntime/core/optimizer/matmul_add_fusion.cc index 8d9d41b75cb06..8afb554ab0833 100644 --- a/onnxruntime/core/optimizer/matmul_add_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_add_fusion.cc @@ -130,7 +130,6 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, } k = dim_k.dim_value(); n = dim_n.dim_value(); - ORT_ENFORCE(shape_values.back() == k); m = std::accumulate(shape_values.begin(), shape_values.end() - 1, static_cast(1), std::multiplies()); } @@ -167,8 +166,10 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, } auto gemm_output_defs = add_node.MutableOutputDefs(); + Node* input_node = nullptr; + Node* output_node = nullptr; if (need_reshape) { - auto add_reshape = [&](const 
std::vector& shape, Graph& graph, bool is_input) { + auto add_reshape = [&](const std::vector& shape, Graph& graph, bool is_input) -> Node* { const std::string name = is_input ? "gemm_input" : "gemm_output"; ONNX_NAMESPACE::TensorProto shape_initializer_proto; shape_initializer_proto.set_name(graph.GenerateNodeName(name + "_shape")); @@ -187,23 +188,47 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, {is_input ? gemm_input_defs[0] : new_arg, shape_arg}, {is_input ? new_arg : gemm_output_defs[0]}); reshape_node.SetExecutionProviderType(matmul_node.GetExecutionProviderType()); - return new_arg; + return &reshape_node; }; - gemm_input_defs[0] = add_reshape({m, k}, graph, true); + input_node = add_reshape({m, k}, graph, true); + gemm_input_defs[0] = input_node->MutableOutputDefs()[0]; shape_values.back() = n; - gemm_output_defs[0] = add_reshape(shape_values, graph, false); + output_node = add_reshape(shape_values, graph, false); + gemm_output_defs[0] = output_node->MutableInputDefs()[0]; } Node& gemm_node = graph.AddNode(graph.GenerateNodeName(matmul_node.Name() + "/MatMulAddFusion/"), "Gemm", "fused Matmul and Add", gemm_input_defs, gemm_output_defs); - - // Assign provider to this new node. Provider should be same as the provider for old node. 
gemm_node.SetExecutionProviderType(matmul_node.GetExecutionProviderType()); + if (need_reshape) { + graph.AddEdge(input_node->Index(), gemm_node.Index(), 0, 0); + graph.AddEdge(gemm_node.Index(), output_node->Index(), 0, 0); + } else { + input_node = &gemm_node; + output_node = &gemm_node; + } + + auto matmul_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(matmul_node); + for (auto cur = matmul_input_edges.cbegin(), end = matmul_input_edges.cend(); cur != end; ++cur) { + if (cur->dst_arg_index == 0) { + graph.AddEdge(cur->src_node, input_node->Index(), cur->src_arg_index, 0); + } else if (cur->dst_arg_index == 1) { + graph.AddEdge(cur->src_node, gemm_node.Index(), cur->src_arg_index, 1); + } + } + graph_utils::GraphEdge::RemoveGraphEdges(graph, matmul_input_edges); + auto add_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(add_node); + for (auto cur = add_input_edges.cbegin(), end = add_input_edges.cend(); cur != end; ++cur) { + if (cur->dst_arg_index == 1) { + graph.AddEdge(cur->src_node, gemm_node.Index(), cur->src_arg_index, 2); + } + } + graph_utils::GraphEdge::RemoveGraphEdges(graph, add_input_edges); graph_utils::RemoveNodeOutputEdges(graph, matmul_node); + graph_utils::ReplaceDownstreamNodeInput(graph, add_node, 0, *output_node, 0); graph.RemoveNode(matmul_node.Index()); - graph_utils::RemoveNodeOutputEdges(graph, add_node); graph.RemoveNode(add_node.Index()); modified = true; diff --git a/onnxruntime/core/optimizer/reshape_fusion.cc b/onnxruntime/core/optimizer/reshape_fusion.cc index 7f94e18458be2..dbd34169e3389 100644 --- a/onnxruntime/core/optimizer/reshape_fusion.cc +++ b/onnxruntime/core/optimizer/reshape_fusion.cc @@ -48,6 +48,8 @@ Status ReshapeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, c fused_count++; LOGS(logger, INFO) << "Fused reshape node: " << reshape.OutputDefs()[0]->Name(); modified = true; + } else if (ReshapeFusion::FuseContiguousReshapes(reshape, graph, logger)) { + modified = true; } } @@ -452,4 
+454,73 @@ bool ReshapeFusion::Fuse_Subgraph(Node& reshape, Graph& graph, const logging::Lo return true; } +bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph, const logging::Logger& logger) { + ORT_UNUSED_PARAMETER(logger); + InlinedVector contiguous_reshapes{&reshape}; + InlinedVector shape_value; + while (true) { + Node* p_curr_node = contiguous_reshapes.back(); + if (graph.NodeProducesGraphOutput(*p_curr_node) || p_curr_node->GetOutputEdgesCount() != 1) { + break; + } + + Node* p_next_node = graph.GetNode(p_curr_node->OutputNodesBegin()->Index()); + if (p_next_node->OpType() != "Reshape" && p_next_node->OpType() != "Squeeze" && + p_next_node->OpType() != "Unsqueeze") { + break; + } + + auto shape = p_next_node->OutputDefs()[0]->Shape(); + if (!shape) { + break; + } + + bool is_concrete_shape = true; + shape_value.clear(); + for (const auto& dim : shape->dim()) { + if (dim.has_dim_value()) { + shape_value.emplace_back(dim.dim_value()); + } else { + is_concrete_shape = false; + } + } + if (!is_concrete_shape) { + break; + } + + contiguous_reshapes.emplace_back(p_next_node); + } + + if (contiguous_reshapes.size() < 2) { + return false; + } + + const std::string& name = contiguous_reshapes[0]->Name(); + ONNX_NAMESPACE::TensorProto shape_initializer_proto; + shape_initializer_proto.set_name(graph.GenerateNodeName(name + "_new_shape")); + shape_initializer_proto.add_dims(static_cast(shape_value.size())); + shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + shape_initializer_proto.set_raw_data(shape_value.data(), shape_value.size() * sizeof(int64_t)); + NodeArg* shape_arg = &graph_utils::AddInitializer(graph, shape_initializer_proto); + Node& reshape_node = graph.AddNode(graph.GenerateNodeName(name + "_new_reshape"), "Reshape", "Reshape for " + name, + {contiguous_reshapes[0]->MutableInputDefs()[0], shape_arg}, + {contiguous_reshapes.back()->MutableOutputDefs()[0]}); + 
reshape_node.SetExecutionProviderType(contiguous_reshapes[0]->GetExecutionProviderType()); + + auto input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*contiguous_reshapes[0]); + for (auto cur = input_edges.cbegin(), end = input_edges.cend(); cur != end; ++cur) { + if (cur->dst_arg_index == 0) { + graph.AddEdge(cur->src_node, reshape_node.Index(), cur->src_arg_index, 0); + } + } + graph_utils::GraphEdge::RemoveGraphEdges(graph, input_edges); + graph_utils::ReplaceDownstreamNodeInput(graph, *contiguous_reshapes.back(), 0, reshape_node, 0); + for (Node* p_node : contiguous_reshapes) { + graph_utils::RemoveNodeOutputEdges(graph, *p_node); + graph.RemoveNode(p_node->Index()); + } + + return true; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/reshape_fusion.h b/onnxruntime/core/optimizer/reshape_fusion.h index f236b516ad9be..f5eb8be48e84c 100644 --- a/onnxruntime/core/optimizer/reshape_fusion.h +++ b/onnxruntime/core/optimizer/reshape_fusion.h @@ -27,6 +27,11 @@ class ReshapeFusion : public GraphTransformer { static bool Is_One_Element_Input(const Node& cur_node, int index); static bool Is_One_Element_Output_Subgraph(Graph& graph, const NodeArg& root_input, const Node& concat, int index, gsl::span shape_value, const logging::Logger& logger); + + // Remove contiguous Reshape/Squeeze/Unsqueeze if the shape info is concrete. + // For some EP, such reshape Ops are not no-op, such as QNN EP, memory is allocated for each output, + // so this fusion can help to reduce memory usage on such devices. 
+ static bool FuseContiguousReshapes(Node& reshape, Graph& graph, const logging::Logger& logger); }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index ed70111087e19..67d7d976b04a8 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -7,8 +7,6 @@ #include #include "core/providers/shared/utils/utils.h" -#include "core/framework/tensorprotoutils.h" -#include "core/providers/cpu/tensor/transpose.h" #include "core/common/safeint.h" namespace onnxruntime { @@ -271,37 +269,6 @@ Status BaseOpBuilder::SetOutputQParamEqualToInputIfNearlyEqual(QnnModelWrapper& return Status::OK(); } -Status BaseOpBuilder::TransposeInitializer(const QnnModelWrapper& qnn_model_wrapper, - const onnx::TensorProto& initializer, - const std::vector& perm, - std::vector& transposed_data) const { - const DataTypeImpl* tensor_dtype = DataTypeImpl::TensorTypeFromONNXEnum(initializer.data_type())->GetElementType(); - const auto tensor_shape_dims = onnxruntime::utils::GetTensorShapeFromTensorProto(initializer); - TensorShape tensor_shape{tensor_shape_dims}; - AllocatorPtr cpu_allocator = std::make_shared(); - Tensor in_tensor = Tensor(tensor_dtype, tensor_shape, cpu_allocator); - - auto rank = perm.size(); - std::vector new_tensor_shape_dims; - std::vector permutations; - new_tensor_shape_dims.reserve(rank); - permutations.reserve(rank); - for (int64_t p : perm) { - permutations.push_back(p); - new_tensor_shape_dims.push_back(tensor_shape_dims[p]); - } - - TensorShape new_tensor_shape(new_tensor_shape_dims); - Tensor out_tensor = Tensor(tensor_dtype, new_tensor_shape, cpu_allocator); - ORT_RETURN_IF_ERROR(onnxruntime::utils::TensorProtoToTensor( - Env::Default(), qnn_model_wrapper.GetGraphViewer().ModelPath(), initializer, in_tensor)); - 
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(permutations, in_tensor, out_tensor)); - onnx::TensorProto new_tensor_proto = onnxruntime::utils::TensorToTensorProto(out_tensor, "test"); - ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(new_tensor_proto, transposed_data)); - - return Status::OK(); -} - Status BaseOpBuilder::ProcessAxisAttribute(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, Qnn_Scalar_t& axis_qnn_scalar, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h index 055c0f6ccf2fa..653195d440a84 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h @@ -214,88 +214,6 @@ class BaseOpBuilder : public IOpBuilder { return it->second; } - // NCHW shape to channel last - Status NchwShapeToNhwc(const std::vector& nchw_shape, std::vector& nhwc_shape) const { - ORT_RETURN_IF_NOT(nchw_shape.size() == 4, "shape should have 4 dimension NCHW."); - nhwc_shape[0] = nchw_shape[0]; - nhwc_shape[1] = nchw_shape[2]; - nhwc_shape[2] = nchw_shape[3]; - nhwc_shape[3] = nchw_shape[1]; - - return Status::OK(); - } - - // NCHW shape to HWCN shape, required for Conv weight - Status NchwShapeToHwcn(const std::vector& nchw_shape, std::vector& hwcn_shape) const { - if (nchw_shape.size() == 4) { - hwcn_shape[0] = nchw_shape[2]; - hwcn_shape[1] = nchw_shape[3]; - hwcn_shape[2] = nchw_shape[1]; - hwcn_shape[3] = nchw_shape[0]; - } else if (nchw_shape.size() == 5) { - hwcn_shape[0] = nchw_shape[2]; - hwcn_shape[1] = nchw_shape[3]; - hwcn_shape[2] = nchw_shape[4]; - hwcn_shape[3] = nchw_shape[1]; - hwcn_shape[4] = nchw_shape[0]; - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported rank! 
only support 4 or 5."); - } - - return Status::OK(); - } - - // CNHW shape to HWCN shape, required for Conv weight - Status CnhwShapeToHwcn(const std::vector& cnhw_shape, std::vector& hwcn_shape) const { - if (cnhw_shape.size() == 4) { - hwcn_shape[0] = cnhw_shape[2]; - hwcn_shape[1] = cnhw_shape[3]; - hwcn_shape[2] = cnhw_shape[0]; - hwcn_shape[3] = cnhw_shape[1]; - } else if (cnhw_shape.size() == 5) { - hwcn_shape[0] = cnhw_shape[2]; - hwcn_shape[1] = cnhw_shape[3]; - hwcn_shape[2] = cnhw_shape[4]; - hwcn_shape[3] = cnhw_shape[0]; - hwcn_shape[4] = cnhw_shape[1]; - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported rank! only support 4 or 5."); - } - - return Status::OK(); - } - Status TransposeInitializer(const QnnModelWrapper& qnn_model_wrapper, - const onnx::TensorProto& initializer, - const std::vector& perm, - std::vector& transposed_data) const; - - Status TransposeFromNchwToHwcn(const QnnModelWrapper& qnn_model_wrapper, - const onnx::TensorProto& initializer, - std::vector& transposed_data, - bool is_3d = false) const { - auto& perm = is_3d ? nchw2hwcn_perm_3d : nchw2hwcn_perm; - return TransposeInitializer(qnn_model_wrapper, initializer, perm, transposed_data); - } - - Status TransposeFromCnhwToHwcn(const QnnModelWrapper& qnn_model_wrapper, - const onnx::TensorProto& initializer, - std::vector& transposed_data, - bool is_3d = false) const { - auto& perm = is_3d ? 
cnhw2hwcn_perm_3d : cnhw2hwcn_perm; - return TransposeInitializer(qnn_model_wrapper, initializer, perm, transposed_data); - } - - Status TwoDimensionTranspose(const QnnModelWrapper& qnn_model_wrapper, - std::vector& data_shape, - const onnx::TensorProto& initializer, - std::vector& transposed_data) const { - auto tmp = data_shape[0]; - data_shape[0] = data_shape[1]; - data_shape[1] = tmp; - std::vector two_dim_trans_perm{1, 0}; - return TransposeInitializer(qnn_model_wrapper, initializer, two_dim_trans_perm, transposed_data); - } - // Onnx Pads is [x1_begin, x2_begin, x1_end, x2_end], QNN requires [x1_begin, x1_end, x2_begin, x2_end] void ReArranagePads(std::vector& pads) const { auto pads_size = pads.size(); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc index 12887f0fb72d6..f50d9f477cc1f 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc @@ -211,9 +211,9 @@ Status ConvOpBuilder::ProcessConv2D3DInputs(QnnModelWrapper& qnn_model_wrapper, // Change shape to HWCN, it could be initializer or normal input if (conv_type == OnnxConvType::kConv) { - ORT_RETURN_IF_ERROR(NchwShapeToHwcn(input_info.shape, actual_shape)); + ORT_RETURN_IF_ERROR(utils::NchwShapeToHwcn(input_info.shape, actual_shape)); } else if (conv_type == OnnxConvType::kConvTranspose) { - ORT_RETURN_IF_ERROR(CnhwShapeToHwcn(input_info.shape, actual_shape)); + ORT_RETURN_IF_ERROR(utils::CnhwShapeToHwcn(input_info.shape, actual_shape)); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unexpected convolution op type: ", node_unit.OpType().c_str()); } @@ -224,9 +224,9 @@ Status ConvOpBuilder::ProcessConv2D3DInputs(QnnModelWrapper& qnn_model_wrapper, if (input_info.is_initializer) { // Get transposed initializer bytes. 
if (conv_type == OnnxConvType::kConv) { - ORT_RETURN_IF_ERROR(TransposeFromNchwToHwcn(qnn_model_wrapper, *input_info.initializer_tensor, unpacked_tensor, is_3d)); + ORT_RETURN_IF_ERROR(utils::TransposeFromNchwToHwcn(qnn_model_wrapper, *input_info.initializer_tensor, unpacked_tensor, is_3d)); } else if (conv_type == OnnxConvType::kConvTranspose) { - ORT_RETURN_IF_ERROR(TransposeFromCnhwToHwcn(qnn_model_wrapper, *input_info.initializer_tensor, unpacked_tensor, is_3d)); + ORT_RETURN_IF_ERROR(utils::TransposeFromCnhwToHwcn(qnn_model_wrapper, *input_info.initializer_tensor, unpacked_tensor, is_3d)); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unexpected convolution op type: ", node_unit.OpType().c_str()); } @@ -413,9 +413,9 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, // Create the final shape after the weights are transposed to HWCN. if (conv_type == OnnxConvType::kConv) { - ORT_RETURN_IF_ERROR(NchwShapeToHwcn(shape_2d, final_shape)); + ORT_RETURN_IF_ERROR(utils::NchwShapeToHwcn(shape_2d, final_shape)); } else if (conv_type == OnnxConvType::kConvTranspose) { - ORT_RETURN_IF_ERROR(CnhwShapeToHwcn(shape_2d, final_shape)); + ORT_RETURN_IF_ERROR(utils::CnhwShapeToHwcn(shape_2d, final_shape)); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unexpected convolution op type: ", node_unit.OpType().c_str()); } @@ -453,9 +453,9 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, // Get transposed initializer bytes. 
// if (conv_type == OnnxConvType::kConv) { - ORT_RETURN_IF_ERROR(TransposeFromNchwToHwcn(qnn_model_wrapper, reshaped_initializer, unpacked_tensor)); + ORT_RETURN_IF_ERROR(utils::TransposeFromNchwToHwcn(qnn_model_wrapper, reshaped_initializer, unpacked_tensor)); } else if (conv_type == OnnxConvType::kConvTranspose) { - ORT_RETURN_IF_ERROR(TransposeFromCnhwToHwcn(qnn_model_wrapper, reshaped_initializer, unpacked_tensor)); + ORT_RETURN_IF_ERROR(utils::TransposeFromCnhwToHwcn(qnn_model_wrapper, reshaped_initializer, unpacked_tensor)); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unexpected convolution op type: ", node_unit.OpType().c_str()); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc index eeee26c177281..754ea77b46511 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc @@ -113,10 +113,8 @@ Status GemmOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name); if (1 == input_trans_flag.at(input_i)) { ORT_RETURN_IF_ERROR(quantize_param.HandleTranspose(std::vector({1, 0}))); - ORT_RETURN_IF_ERROR(TwoDimensionTranspose(qnn_model_wrapper, - input_shape, - *input_tensor, - unpacked_tensor)); + ORT_RETURN_IF_ERROR( + utils::TwoDimensionTranspose(qnn_model_wrapper, input_shape, *input_tensor, unpacked_tensor)); } else { ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor)); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc index 9fb9e815321c0..cf549148550f1 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ 
-18,6 +18,7 @@ #include "core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h" +#include "core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h" namespace onnxruntime { namespace qnn { @@ -92,6 +93,7 @@ static std::unique_ptr TryQnnFusions( {"HardSigmoid", HardSigmoidMulFusion::TryFusion}, {"Conv", ConvActivationFusion::TryFusion}, {"ConvTranspose", ConvActivationFusion::TryFusion}, + {"Gemm", ReshapeGemmFusion::TryFusion}, }; // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes). diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.cc new file mode 100644 index 0000000000000..93dd04c5200fa --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.cc @@ -0,0 +1,208 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h" + +#include +#include +#include +#include +#include +#include + +#include "core/graph/graph_utils.h" +#include "core/framework/node_unit.h" +#include "core/framework/tensorprotoutils.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace qnn { + +static const NodeUnit* GetReshapeNodeUnit( + const GraphViewer& graph_viewer, const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const Node& gemm_node) { + if (gemm_node.OpType() != "Gemm") { + return nullptr; + } + for (auto it = gemm_node.InputEdgesBegin(); it != gemm_node.InputEdgesEnd(); it++) { + if (it->GetDstArgIndex() == 0) { + const Node& reshape_node = it->GetNode(); + if (reshape_node.OpType() == "Reshape" && !graph_viewer.NodeProducesGraphOutput(reshape_node) && + reshape_node.GetOutputEdgesCount() == 1) { + const auto it = node_to_node_unit.find(&reshape_node); + if (it != node_to_node_unit.end()) { + const NodeUnit* reshape_node_unit = it->second; + if (reshape_node_unit && node_unit_to_qnn_node_group.count(reshape_node_unit) == 0 && + reshape_node_unit->UnitType() == NodeUnit::Type::SingleNode) { + return reshape_node_unit; + } + } + } + } + } + return nullptr; +} + +static bool CheckShape(const GraphViewer& graph_viewer, const Node& reshape_node) { + auto tensor_shape = reshape_node.InputDefs()[0]->Shape(); + if (!tensor_shape) return false; + InlinedVector input_shape; + for (const auto& dim : tensor_shape->dim()) { + if (dim.value_case() != ONNX_NAMESPACE::TensorShapeProto_Dimension::kDimValue) return false; + input_shape.emplace_back(dim.dim_value()); + } + + const ONNX_NAMESPACE::TensorProto* shape_proto = + 
graph_viewer.GetConstantInitializer(reshape_node.InputDefs()[1]->Name()); + if (!shape_proto) return false; + const auto* dtype = DataTypeImpl::TensorTypeFromONNXEnum(shape_proto->data_type())->GetElementType(); + TensorShape shape = onnxruntime::utils::GetTensorShapeFromTensorProto(*shape_proto); + Tensor tensor(dtype, shape, std::make_shared()); + if (onnxruntime::utils::TensorProtoToTensor(onnxruntime::Env::Default(), graph_viewer.ModelPath(), *shape_proto, + tensor) != Status::OK()) { + return false; + } + + InlinedVector output_shape; + if (tensor.IsDataType()) { + gsl::span tensor_elems = tensor.DataAsSpan(); + output_shape.insert(output_shape.end(), tensor_elems.begin(), tensor_elems.end()); + } else if (tensor.IsDataType()) { + gsl::span tensor_elems = tensor.DataAsSpan(); + for (int32_t elem : tensor_elems) { + output_shape.emplace_back(static_cast(elem)); + } + } + + return !input_shape.empty() && output_shape.size() == 2 && input_shape.back() == output_shape.back(); +} + +#define ValidateOnQnn(qnn_model_wrapper, reshape_node_unit, gemm_node_unit) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (reshape_node_unit), (gemm_node_unit), true) +#define CreateOnQnn(qnn_model_wrapper, reshape_node_unit, gemm_node_unit) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (reshape_node_unit), (gemm_node_unit), false) +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& reshape_node_unit, + const NodeUnit& gemm_node_unit, bool validate) { + assert(reshape_node_unit.OpType() == "Reshape" && gemm_node_unit.OpType() == "Gemm"); + const auto& node_name = utils::GetNodeName(gemm_node_unit); + const NodeUnitIODef& input_def = reshape_node_unit.Inputs()[0]; + const NodeUnitIODef& weight_def = gemm_node_unit.Inputs()[1]; + const NodeUnitIODef* bias_def_ptr = nullptr; + bool has_bias = gemm_node_unit.Inputs().size() == 3; + if (has_bias) { + bias_def_ptr = &gemm_node_unit.Inputs()[2]; + } + const NodeUnitIODef& output_def = 
gemm_node_unit.Outputs()[0]; + + QnnTensorWrapper input_tensor; + QnnTensorWrapper bias_tensor; + QnnTensorWrapper output_tensor; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor)); + std::vector weight_shape; + std::vector unpacked_tensor; + std::string weight_tensor_name = weight_def.node_arg.Name(); + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(weight_def.node_arg, weight_shape), "Failed to get weight shape"); + Qnn_TensorType_t tensor_type = qnn_model_wrapper.GetTensorType(weight_tensor_name); + Qnn_DataType_t data_type = QNN_DATATYPE_FLOAT_32; + ORT_RETURN_IF_ERROR(utils::GetQnnDataType(false, weight_def.node_arg.TypeAsProto(), data_type)); + const auto& weight_tensor_proto = qnn_model_wrapper.GetInitializerTensors().at(weight_tensor_name); + ORT_RETURN_IF_ERROR( + utils::TwoDimensionTranspose(qnn_model_wrapper, weight_shape, *weight_tensor_proto, unpacked_tensor)); + QnnTensorWrapper weight_tensor(weight_tensor_name, tensor_type, data_type, QnnQuantParamsWrapper(), + std::move(weight_shape), std::move(unpacked_tensor)); + if (has_bias) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(*bias_def_ptr, bias_tensor)); + } + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor)); + + if (validate) { + std::vector input_tensors = {input_tensor.GetQnnTensor(), weight_tensor.GetQnnTensor()}; + if (has_bias) { + input_tensors.emplace_back(bias_tensor.GetQnnTensor()); + } + ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_FULLY_CONNECTED, std::move(input_tensors), + {output_tensor.GetQnnTensor()}, {})); + } else { + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(weight_tensor)), "Failed to add weight"); + if (has_bias) { + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(bias_tensor)), "Failed to add 
bias"); + } + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); + std::vector input_names = {input_def.node_arg.Name(), weight_tensor_name}; + if (has_bias) { + input_names.emplace_back(bias_def_ptr->node_arg.Name()); + } + ORT_RETURN_IF_NOT( + qnn_model_wrapper.CreateQnnNode(node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_FULLY_CONNECTED, + std::move(input_names), {output_def.node_arg.Name()}, {}, validate), + "Failed to add fused Gemm node."); + } + return Status::OK(); +} + +std::unique_ptr ReshapeGemmFusion::TryFusion( + QnnModelWrapper& qnn_model_wrapper, const NodeUnit& gemm_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { + ORT_UNUSED_PARAMETER(logger); + if (gemm_node_unit.OpType() != "Gemm" || gemm_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + return nullptr; + } + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const Node& gemm_node = gemm_node_unit.GetNode(); + NodeAttrHelper helper(gemm_node); + auto transA = helper.Get("transA", static_cast(0)); + auto transB = helper.Get("transB", static_cast(0)); + const auto& weight_input = gemm_node_unit.Inputs()[1]; + // The pattern is from MatMul->Add, so the transA and transB should be false, and weight should be initializer. + // Currently we don't handle quantized weight. 
+ if (transA != 0 || transB != 0 || !qnn_model_wrapper.IsInitializerInput(weight_input.node_arg.Name()) || + weight_input.quant_param.has_value()) { + return nullptr; + } + + const NodeUnit* reshape_node_unit = + GetReshapeNodeUnit(graph_viewer, node_to_node_unit, node_unit_to_qnn_node_group, gemm_node); + if (!reshape_node_unit) { + return nullptr; + } + + if (!CheckShape(graph_viewer, reshape_node_unit->GetNode())) { + return nullptr; + } + + return std::make_unique(*reshape_node_unit, gemm_node_unit); +} + +ReshapeGemmFusion::ReshapeGemmFusion(const NodeUnit& reshape_node_unit, const NodeUnit& gemm_node_unit) + : node_units_{} { + node_units_[0] = &reshape_node_unit; + node_units_[1] = &gemm_node_unit; +} + +Status ReshapeGemmFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + return ValidateOnQnn(qmw, *node_units_[0], *node_units_[1]); +} + +Status ReshapeGemmFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + return CreateOnQnn(qmw, *node_units_[0], *node_units_[1]); +} + +gsl::span ReshapeGemmFusion::GetNodeUnits() const { + return gsl::make_span(node_units_.data(), 2); +} + +const NodeUnit* ReshapeGemmFusion::GetTargetNodeUnit() const { return node_units_[1]; } + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h new file mode 100644 index 0000000000000..a07e8de5a8e2b --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; + +/// +/// Represents a fusion of a Reshape->Gemm sequence to a single Gemm node. +/// Ideally Reshape->Gemm->Reshape should be fused to a single Gemm node with keep_dims set to True, +/// but on some devices the OpConfig validation will fail when keep_dims to True (it says expected value is 0), +/// so we still need to keep the 2nd Reshape node. +/// +class ReshapeGemmFusion : public IQnnNodeGroup { + public: + ReshapeGemmFusion(const NodeUnit& reshape_node_unit, const NodeUnit& gemm_node_unit); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ReshapeGemmFusion); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + gsl::span GetNodeUnits() const override; + const NodeUnit* GetTargetNodeUnit() const override; + std::string_view Type() const override { return "ReshapeGemmFusion"; } + + static std::unique_ptr TryFusion( + QnnModelWrapper& qnn_model_wrapper, const NodeUnit& gemm_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + + private: + std::array node_units_; +}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc index 8d2cb5bdb6da0..67026cf44fa69 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include "core/providers/qnn/builder/qnn_utils.h" + #include #include #include @@ -9,7 +11,8 @@ #include "core/common/common.h" #include "core/framework/data_types.h" -#include "qnn_utils.h" +#include "core/framework/tensorprotoutils.h" +#include "core/providers/cpu/tensor/transpose.h" #include "core/providers/qnn/builder/qnn_def.h" namespace onnxruntime { @@ -570,6 +573,109 @@ Status Quantize(const double double_value, return Status::OK(); } +// Converts an NCHW shape to channel-last (NHWC) order; the input must have rank 4. nhwc_shape is written by index and is not resized here. +Status NchwShapeToNhwc(const std::vector& nchw_shape, std::vector& nhwc_shape) { + ORT_RETURN_IF_NOT(nchw_shape.size() == 4, "shape should have 4 dimension NCHW."); + nhwc_shape[0] = nchw_shape[0]; + nhwc_shape[1] = nchw_shape[2]; + nhwc_shape[2] = nchw_shape[3]; + nhwc_shape[3] = nchw_shape[1]; + + return Status::OK(); +} + +// Converts a rank-4 or rank-5 NCHW shape to HWCN order, required for Conv weight; other ranks yield a FAIL status. hwcn_shape is written by index and is not resized here. +Status NchwShapeToHwcn(const std::vector& nchw_shape, std::vector& hwcn_shape) { + if (nchw_shape.size() == 4) { + hwcn_shape[0] = nchw_shape[2]; + hwcn_shape[1] = nchw_shape[3]; + hwcn_shape[2] = nchw_shape[1]; + hwcn_shape[3] = nchw_shape[0]; + } else if (nchw_shape.size() == 5) { + hwcn_shape[0] = nchw_shape[2]; + hwcn_shape[1] = nchw_shape[3]; + hwcn_shape[2] = nchw_shape[4]; + hwcn_shape[3] = nchw_shape[1]; + hwcn_shape[4] = nchw_shape[0]; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported rank! 
only support 4 or 5."); + } + + return Status::OK(); +} + +// CNHW shape to HWCN shape, required for Conv weight +Status CnhwShapeToHwcn(const std::vector& cnhw_shape, std::vector& hwcn_shape) { + if (cnhw_shape.size() == 4) { + hwcn_shape[0] = cnhw_shape[2]; + hwcn_shape[1] = cnhw_shape[3]; + hwcn_shape[2] = cnhw_shape[0]; + hwcn_shape[3] = cnhw_shape[1]; + } else if (cnhw_shape.size() == 5) { + hwcn_shape[0] = cnhw_shape[2]; + hwcn_shape[1] = cnhw_shape[3]; + hwcn_shape[2] = cnhw_shape[4]; + hwcn_shape[3] = cnhw_shape[0]; + hwcn_shape[4] = cnhw_shape[1]; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported rank! only support 4 or 5."); + } + + return Status::OK(); +} + +namespace { +Status TransposeInitializer(const QnnModelWrapper& qnn_model_wrapper, const onnx::TensorProto& initializer, + const std::vector& perm, std::vector& transposed_data) { + const DataTypeImpl* tensor_dtype = DataTypeImpl::TensorTypeFromONNXEnum(initializer.data_type())->GetElementType(); + const auto tensor_shape_dims = onnxruntime::utils::GetTensorShapeFromTensorProto(initializer); + TensorShape tensor_shape{tensor_shape_dims}; + AllocatorPtr cpu_allocator = std::make_shared(); + Tensor in_tensor = Tensor(tensor_dtype, tensor_shape, cpu_allocator); + + auto rank = perm.size(); + std::vector new_tensor_shape_dims; + std::vector permutations; + new_tensor_shape_dims.reserve(rank); + permutations.reserve(rank); + for (int64_t p : perm) { + permutations.push_back(p); + new_tensor_shape_dims.push_back(tensor_shape_dims[p]); + } + + TensorShape new_tensor_shape(new_tensor_shape_dims); + Tensor out_tensor = Tensor(tensor_dtype, new_tensor_shape, cpu_allocator); + ORT_RETURN_IF_ERROR(onnxruntime::utils::TensorProtoToTensor( + Env::Default(), qnn_model_wrapper.GetGraphViewer().ModelPath(), initializer, in_tensor)); + ORT_RETURN_IF_ERROR(Transpose::DoTranspose(permutations, in_tensor, out_tensor)); + onnx::TensorProto new_tensor_proto = 
onnxruntime::utils::TensorToTensorProto(out_tensor, "test"); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(new_tensor_proto, transposed_data)); + + return Status::OK(); +} +} // namespace + +Status TransposeFromNchwToHwcn(const QnnModelWrapper& qnn_model_wrapper, const onnx::TensorProto& initializer, + std::vector& transposed_data, bool is_3d) { + auto& perm = is_3d ? nchw2hwcn_perm_3d : nchw2hwcn_perm; + return TransposeInitializer(qnn_model_wrapper, initializer, perm, transposed_data); +} + +Status TransposeFromCnhwToHwcn(const QnnModelWrapper& qnn_model_wrapper, const onnx::TensorProto& initializer, + std::vector& transposed_data, bool is_3d) { + auto& perm = is_3d ? cnhw2hwcn_perm_3d : cnhw2hwcn_perm; + return TransposeInitializer(qnn_model_wrapper, initializer, perm, transposed_data); +} + +Status TwoDimensionTranspose(const QnnModelWrapper& qnn_model_wrapper, std::vector& data_shape, + const onnx::TensorProto& initializer, std::vector& transposed_data) { + auto tmp = data_shape[0]; + data_shape[0] = data_shape[1]; + data_shape[1] = tmp; + std::vector two_dim_trans_perm{1, 0}; + return TransposeInitializer(qnn_model_wrapper, initializer, two_dim_trans_perm, transposed_data); +} + } // namespace utils } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.h b/onnxruntime/core/providers/qnn/builder/qnn_utils.h index aa4a27460563f..76743f4b8c69b 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.h @@ -12,6 +12,7 @@ #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/node_unit.h" #include "core/util/qmath.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" namespace onnxruntime { namespace qnn { @@ -104,6 +105,24 @@ Status Quantize(const double double_value, const Qnn_DataType_t qnn_data_type, int& quant_value); +// Converts an NCHW shape to channel-last (NHWC) order; the input must have rank 4. nhwc_shape is written by index and is not resized here. +Status NchwShapeToNhwc(const std::vector& nchw_shape, 
std::vector& nhwc_shape); + +// Converts a rank-4 or rank-5 NCHW shape to HWCN order, required for Conv weight; other ranks yield a FAIL status. hwcn_shape is written by index and is not resized here. +Status NchwShapeToHwcn(const std::vector& nchw_shape, std::vector& hwcn_shape); + +// Converts a rank-4 or rank-5 CNHW shape to HWCN order, required for Conv weight; other ranks yield a FAIL status. hwcn_shape is written by index and is not resized here. +Status CnhwShapeToHwcn(const std::vector& cnhw_shape, std::vector& hwcn_shape); + +Status TransposeFromNchwToHwcn(const QnnModelWrapper& qnn_model_wrapper, const onnx::TensorProto& initializer, + std::vector& transposed_data, bool is_3d = false); + +Status TransposeFromCnhwToHwcn(const QnnModelWrapper& qnn_model_wrapper, const onnx::TensorProto& initializer, + std::vector& transposed_data, bool is_3d = false); + +Status TwoDimensionTranspose(const QnnModelWrapper& qnn_model_wrapper, std::vector& data_shape, + const onnx::TensorProto& initializer, std::vector& transposed_data); + } // namespace utils } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 6448961df9331..8d8a32796806e 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -4110,6 +4110,38 @@ TEST_F(GraphTransformationTests, ReshapeFusionDistilBertTest) { } } +TEST_F(GraphTransformationTests, ReshapeFusion_Contiguous_Reshape) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({{8, 16, 32}}); + auto* shape_initializer = builder.MakeInitializer({4}, {2, 4, 16, 32}); + auto* axes_initializer = builder.MakeInitializer({1}, {1}); + auto* reshape_out = builder.MakeIntermediate(); + auto* unsqueeze_out = builder.MakeIntermediate(); + auto* output_arg = builder.MakeOutput(); + builder.AddNode("Reshape", {input_arg, shape_initializer}, {reshape_out}); + builder.AddNode("Unsqueeze", {reshape_out, axes_initializer}, {unsqueeze_out}); + builder.AddNode("Identity", {unsqueeze_out}, {output_arg}); + }; + + auto pre_graph_checker = [](Graph& graph) { + std::map op_to_count 
= CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Unsqueeze"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) { + std::map op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Unsqueeze"] == 0); + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); +} + // Test eliminating redundant Concat-Slice pattern. TEST_F(GraphTransformationTests, ConcatSliceEliminationTest) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "concat_slice_basic_test.onnx"; diff --git a/onnxruntime/test/providers/qnn/gemm_op_test.cc b/onnxruntime/test/providers/qnn/gemm_op_test.cc index 33c868694c9c0..6ebc02f8ad803 100644 --- a/onnxruntime/test/providers/qnn/gemm_op_test.cc +++ b/onnxruntime/test/providers/qnn/gemm_op_test.cc @@ -177,6 +177,47 @@ TEST_F(QnnCPUBackendTests, Gemm_Broadcast_Bias_DynamicA_StaticB_StaticC) { ExpectedEPNodeAssignment::All); } +namespace { +GetTestModelFn BuildReshapeGemmTestCase(const TestInputDef& input, const TestInputDef& shape, + const TestInputDef& weight, const TestInputDef& bias) { + return [&](ModelTestBuilder& builder) { + std::vector reshape_inputs = {MakeTestInput(builder, input), + MakeTestInput(builder, shape)}; + auto* reshape_output = builder.MakeIntermediate(); + builder.AddNode("Reshape", reshape_inputs, {reshape_output}); + NodeArg* output = builder.MakeOutput(); + std::vector gemm_inputs = {reshape_output, MakeTestInput(builder, weight), + MakeTestInput(builder, bias)}; + builder.AddNode("Gemm", gemm_inputs, {output}); + }; +} + +void RunCPUReshapeGemmTest(const TestInputDef& input, const TestInputDef& shape, + const TestInputDef& weight, const TestInputDef& bias, + ExpectedEPNodeAssignment 
expected_ep_assignment, float fp32_abs_err = 1e-5f) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + auto build_fn = BuildReshapeGemmTestCase(input, shape, weight, bias); + RunQnnModelTest(build_fn, provider_options, 18, expected_ep_assignment, fp32_abs_err); +} + +} // namespace + +TEST_F(QnnCPUBackendTests, ReshapeGemmFusion) { + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}; + std::vector shape_data = {4, 2}; + std::vector weight_data(6, 1.0f); + std::vector bias_data = {1.0f, 2.0f, 3.0f}; + RunCPUReshapeGemmTest(TestInputDef({2, 2, 2}, false, input_data), TestInputDef({2}, true, shape_data), + TestInputDef({2, 3}, true, weight_data), TestInputDef({3}, true, bias_data), + ExpectedEPNodeAssignment::All); +} + #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) // // HTP tests: