diff --git a/onnxruntime/core/optimizer/matmul_add_fusion.cc b/onnxruntime/core/optimizer/matmul_add_fusion.cc index 2a4916ccb324a..4cffd54951850 100644 --- a/onnxruntime/core/optimizer/matmul_add_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_add_fusion.cc @@ -11,10 +11,63 @@ using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; namespace onnxruntime { +namespace { + +// Attention subgraph has 4 MatMul-Add pairs, that we want to skip here because AttentionFusion will handle it. +// In such case, 3 of MatMul-Add pairs are following LN, the other one produces output which is added with LN's output. +// Use two sets to remember such patterns we already met during the graph iteration so that we can skip them directly +// if we go to other MatMul-Add pairs in the same pattern. +struct AttentionPatternCache { + bool IsAttentionPattern(const Graph& graph, const Node& matmul_node, const Node& add_node) { + const Node* parent_node = graph.GetProducerNode(matmul_node.InputDefs()[0]->Name()); + if (attn_ln_nodes.count(parent_node) > 0 || attn_add_nodes.count(&add_node) > 0) { + return true; + } + + if (parent_node && parent_node->OpType() == "LayerNormalization") { + unsigned int add_count = 0; + unsigned int matmul_count = 0; + unsigned int shape_count = 0; + const Node* ln_add_node = nullptr; + for (auto it = parent_node->OutputNodesBegin(); it != parent_node->OutputNodesEnd(); ++it) { + std::string op_type = (*it).OpType(); + if (op_type == "Add") { + ln_add_node = &(*it); + add_count++; + } else if (op_type == "MatMul") { + matmul_count++; + } else if (op_type == "Shape") { + shape_count++; + } + } + + if (add_count == 1 && matmul_count == 3 && shape_count == parent_node->GetOutputEdgesCount() - 4) { + size_t index = ln_add_node->InputDefs()[0]->Name() == parent_node->OutputDefs()[0]->Name() ? 1 : 0; + const Node* attn_add_node = graph.GetProducerNode(ln_add_node->InputDefs()[index]->Name()); + if (attn_add_node && attn_add_node->OpType() == "Add") { + attn_ln_nodes.insert(parent_node); + attn_add_nodes.insert(attn_add_node); + return true; + } + } + } + + return false; + } + + std::unordered_set attn_ln_nodes; + std::unordered_set attn_add_nodes; +}; + +} // namespace + Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + // Cache for skipping Attention subgraph pattern. + AttentionPatternCache attn_pattern_cache; + for (auto node_index : node_topology_list) { auto* node_ptr = graph.GetNode(node_index); if (!node_ptr) @@ -65,26 +118,45 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, // Gemm only support Matrix, need to check the shape of MatMul and Add auto matmul_a_shape = matmul_input_defs[0]->Shape(); auto matmul_b_shape = matmul_input_defs[1]->Shape(); - if (nullptr == matmul_a_shape || nullptr == matmul_b_shape) { + if (nullptr == matmul_a_shape || nullptr == matmul_b_shape || matmul_b_shape->dim_size() != 2) { continue; } - if (2 != matmul_a_shape->dim_size() || 2 != matmul_b_shape->dim_size()) { - // Gemm only support Matrix - continue; + bool need_reshape = matmul_a_shape->dim_size() != 2; + const auto& dim_n = matmul_b_shape->dim(1); + InlinedVector shape_values; + int64_t m = 0, k = 0, n = 0; + if (need_reshape) { + // Only check and skip Attention pattern here because normally input to Attention is 4D. 
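+        // If it is not an Attention pattern, flatten input A to 2D ([M, K]) with a Reshape before the Gemm, and
+        // restore the original leading dimensions (with the last dimension replaced by N) with a Reshape after it.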
+ if (attn_pattern_cache.IsAttentionPattern(graph, matmul_node, add_node)) { + continue; + } + + // Logically we can use Shape-Concat to produce shape input for Reshape, to keep it simple, we require + // both inputs have concrete shape for now, we can add dynamic shape support in future. + auto a_shape = utils::GetTensorShapeFromTensorShapeProto(*matmul_a_shape); + if (a_shape.Size() == -1) { + continue; + } + + const auto& dim_k = matmul_b_shape->dim(0); + if (!utils::HasDimValue(dim_k) || !utils::HasDimValue(dim_n)) { + continue; + } + + shape_values = a_shape.AsShapeVector(); + // If a_shape is 1D, m is 1 from SizeToDimension() with empty dimension interval. + m = a_shape.SizeToDimension(a_shape.NumDimensions() - 1); + k = dim_k.dim_value(); + n = dim_n.dim_value(); } const auto& matmul_output = *matmul_node.OutputDefs()[0]; auto matmul_output_name = matmul_output.Name(); auto gemm_input_defs = matmul_input_defs; - if (matmul_output_name == add_input_defs[0]->Name()) { - // matmul output as Add_A, should use Add_B as input C for gemm - gemm_input_defs.push_back(add_input_defs[1]); - } else { - // matmul output as Add_B, should use Add_A as input C for gemm - gemm_input_defs.push_back(add_input_defs[0]); - } + int bias_idx = matmul_output_name == add_input_defs[0]->Name() ? 1 : 0; + gemm_input_defs.push_back(add_input_defs[bias_idx]); // valid bias_shapes are (N) or (1, N) or (M, 1) or (M, N) as // GEMM only supports unidirectional broadcast on the bias input C @@ -92,31 +164,87 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, continue; } const auto& bias_shape = *gemm_input_defs.back()->Shape(); - const auto& M = matmul_output.Shape()->dim()[0]; - const auto& N = matmul_output.Shape()->dim()[1]; auto dim_has_value_1 = [](const TensorShapeProto_Dimension& dim) { return dim.has_dim_value() && dim.dim_value() == 1; }; - bool valid = ((bias_shape.dim_size() == 1 && bias_shape.dim()[0] == N) || - (bias_shape.dim_size() == 2 && dim_has_value_1(bias_shape.dim()[0]) && bias_shape.dim()[1] == N) || - (bias_shape.dim_size() == 2 && bias_shape.dim()[0] == M && - (dim_has_value_1(bias_shape.dim()[1]) || bias_shape.dim()[1] == N))); + bool valid = ((bias_shape.dim_size() == 1 && bias_shape.dim(0) == dim_n) || + (!need_reshape && bias_shape.dim_size() == 2 && dim_has_value_1(bias_shape.dim(0)) && + bias_shape.dim(1) == dim_n) || + (!need_reshape && bias_shape.dim_size() == 2 && bias_shape.dim(0) == matmul_a_shape->dim(0) && + (dim_has_value_1(bias_shape.dim(1)) || bias_shape.dim(1) == dim_n))); if (!valid) { continue; } - Node& gemm_node = graph.AddNode(graph.GenerateNodeName(matmul_node.Name() + "/MatMulAddFusion/"), - "Gemm", - "fused Matmul and Add " + add_node.OpType(), - gemm_input_defs, - {}); + auto gemm_output_defs = add_node.MutableOutputDefs(); + Node* input_node = nullptr; + Node* output_node = nullptr; + if (need_reshape) { + auto add_reshape = [&](const InlinedVector& shape, Graph& graph, bool is_input) -> Node* { + const std::string name = is_input ? 
"gemm_input" : "gemm_output"; + ONNX_NAMESPACE::TensorProto shape_initializer_proto; + shape_initializer_proto.set_name(graph.GenerateNodeName(name + "_shape")); + shape_initializer_proto.add_dims(static_cast(shape.size())); + shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + shape_initializer_proto.set_raw_data(shape.data(), shape.size() * sizeof(int64_t)); + NodeArg* shape_arg = &graph_utils::AddInitializer(graph, shape_initializer_proto); + ONNX_NAMESPACE::TypeProto new_arg_type; + const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast( + gemm_input_defs[0]->TypeAsProto()->tensor_type().elem_type()); + new_arg_type.mutable_tensor_type()->set_elem_type(element_type); + new_arg_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(m); + new_arg_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(is_input ? k : n); + NodeArg* new_arg = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(name + "_reshape_arg"), &new_arg_type); + Node& reshape_node = graph.AddNode(graph.GenerateNodeName(name + "_reshape"), "Reshape", "Reshape for " + name, + {is_input ? gemm_input_defs[0] : new_arg, shape_arg}, + {is_input ? new_arg : gemm_output_defs[0]}); + reshape_node.SetExecutionProviderType(matmul_node.GetExecutionProviderType()); + return &reshape_node; + }; + + input_node = add_reshape({m, k}, graph, true); + gemm_input_defs[0] = input_node->MutableOutputDefs()[0]; + shape_values.back() = n; + output_node = add_reshape(shape_values, graph, false); + gemm_output_defs[0] = output_node->MutableInputDefs()[0]; + } - // Assign provider to this new node. Provider should be same as the provider for old node. + Node& gemm_node = graph.AddNode(graph.GenerateNodeName(matmul_node.Name() + "/MatMulAddFusion"), "Gemm", + "fused Matmul and Add", gemm_input_defs, gemm_output_defs); gemm_node.SetExecutionProviderType(matmul_node.GetExecutionProviderType()); - // move output definitions and edges from act_node to gemm_node. delete gemm_node and act_node. 
- graph_utils::FinalizeNodeFusion(graph, {matmul_node, add_node}, gemm_node); + if (need_reshape) { + graph.AddEdge(input_node->Index(), gemm_node.Index(), 0, 0); + graph.AddEdge(gemm_node.Index(), output_node->Index(), 0, 0); + } else { + input_node = &gemm_node; + output_node = &gemm_node; + } + + auto matmul_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(matmul_node); + for (auto cur = matmul_input_edges.cbegin(), end = matmul_input_edges.cend(); cur != end; ++cur) { + if (cur->dst_arg_index == 0) { + graph.AddEdge(cur->src_node, input_node->Index(), cur->src_arg_index, 0); + } else if (cur->dst_arg_index == 1) { + graph.AddEdge(cur->src_node, gemm_node.Index(), cur->src_arg_index, 1); + } + } + + graph_utils::GraphEdge::RemoveGraphEdges(graph, matmul_input_edges); + auto add_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(add_node); + for (auto cur = add_input_edges.cbegin(), end = add_input_edges.cend(); cur != end; ++cur) { + if (cur->dst_arg_index == bias_idx) { + graph.AddEdge(cur->src_node, gemm_node.Index(), cur->src_arg_index, 2); + break; + } + } + + graph_utils::GraphEdge::RemoveGraphEdges(graph, add_input_edges); + graph_utils::RemoveNodeOutputEdges(graph, matmul_node); + graph_utils::ReplaceDownstreamNodeInput(graph, add_node, 0, *output_node, 0); + graph.RemoveNode(matmul_node.Index()); + graph.RemoveNode(add_node.Index()); modified = true; } diff --git a/onnxruntime/core/optimizer/reshape_fusion.cc b/onnxruntime/core/optimizer/reshape_fusion.cc index 7f94e18458be2..35ab127284fd5 100644 --- a/onnxruntime/core/optimizer/reshape_fusion.cc +++ b/onnxruntime/core/optimizer/reshape_fusion.cc @@ -48,6 +48,8 @@ Status ReshapeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, c fused_count++; LOGS(logger, INFO) << "Fused reshape node: " << reshape.OutputDefs()[0]->Name(); modified = true; + } else if (ReshapeFusion::FuseContiguousReshapes(reshape, graph)) { + modified = true; } } @@ -452,4 +454,53 @@ bool ReshapeFusion::Fuse_Subgraph(Node& reshape, Graph& graph, const logging::Lo return true; } +bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph) { + InlinedVector> contiguous_reshapes{reshape}; + InlinedVector shape_value; + while (true) { + Node& curr_node = contiguous_reshapes.back(); + if (graph.NodeProducesGraphOutput(curr_node) || curr_node.GetOutputEdgesCount() != 1) { + break; + } + + Node* next_node = graph.GetNode(curr_node.OutputNodesBegin()->Index()); + if (next_node->OpType() != "Reshape" && next_node->OpType() != "Squeeze" && next_node->OpType() != "Unsqueeze") { + break; + } + + auto shape = next_node->OutputDefs()[0]->Shape(); + if (!shape) { + break; + } + + auto tensor_shape = utils::GetTensorShapeFromTensorShapeProto(*shape); + if (tensor_shape.Size() == -1) { + break; + } + + shape_value = tensor_shape.AsShapeVector(); + contiguous_reshapes.emplace_back(*next_node); + } + + if (contiguous_reshapes.size() < 2) { + return false; + } + + const std::string& name = contiguous_reshapes[0].get().Name(); + ONNX_NAMESPACE::TensorProto shape_initializer_proto; + shape_initializer_proto.set_name(graph.GenerateNodeName(name + "_new_shape")); + shape_initializer_proto.add_dims(static_cast(shape_value.size())); + shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + shape_initializer_proto.set_raw_data(shape_value.data(), shape_value.size() * sizeof(int64_t)); + NodeArg* shape_arg = &graph_utils::AddInitializer(graph, shape_initializer_proto); + Node& reshape_node = 
graph.AddNode(graph.GenerateNodeName(name + "_new_reshape"), "Reshape", "Reshape for " + name, + {contiguous_reshapes[0].get().MutableInputDefs()[0], shape_arg}, + {contiguous_reshapes.back().get().MutableOutputDefs()[0]}); + reshape_node.SetExecutionProviderType(contiguous_reshapes[0].get().GetExecutionProviderType()); + + graph_utils::FinalizeNodeFusion(graph, contiguous_reshapes, reshape_node); + + return true; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/reshape_fusion.h b/onnxruntime/core/optimizer/reshape_fusion.h index f236b516ad9be..a2866aa3412fb 100644 --- a/onnxruntime/core/optimizer/reshape_fusion.h +++ b/onnxruntime/core/optimizer/reshape_fusion.h @@ -27,6 +27,11 @@ class ReshapeFusion : public GraphTransformer { static bool Is_One_Element_Input(const Node& cur_node, int index); static bool Is_One_Element_Output_Subgraph(Graph& graph, const NodeArg& root_input, const Node& concat, int index, gsl::span shape_value, const logging::Logger& logger); + + // Remove contiguous Reshape/Squeeze/Unsqueeze if the shape info is concrete. + // For some EP, such reshape Ops are not no-op, such as QNN EP, memory is allocated for each output, + // so this fusion can help to reduce memory usage on such devices. + static bool FuseContiguousReshapes(Node& reshape, Graph& graph); }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index ed70111087e19..67d7d976b04a8 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -7,8 +7,6 @@ #include #include "core/providers/shared/utils/utils.h" -#include "core/framework/tensorprotoutils.h" -#include "core/providers/cpu/tensor/transpose.h" #include "core/common/safeint.h" namespace onnxruntime { @@ -271,37 +269,6 @@ Status BaseOpBuilder::SetOutputQParamEqualToInputIfNearlyEqual(QnnModelWrapper& return Status::OK(); } -Status BaseOpBuilder::TransposeInitializer(const QnnModelWrapper& qnn_model_wrapper, - const onnx::TensorProto& initializer, - const std::vector& perm, - std::vector& transposed_data) const { - const DataTypeImpl* tensor_dtype = DataTypeImpl::TensorTypeFromONNXEnum(initializer.data_type())->GetElementType(); - const auto tensor_shape_dims = onnxruntime::utils::GetTensorShapeFromTensorProto(initializer); - TensorShape tensor_shape{tensor_shape_dims}; - AllocatorPtr cpu_allocator = std::make_shared(); - Tensor in_tensor = Tensor(tensor_dtype, tensor_shape, cpu_allocator); - - auto rank = perm.size(); - std::vector new_tensor_shape_dims; - std::vector permutations; - new_tensor_shape_dims.reserve(rank); - permutations.reserve(rank); - for (int64_t p : perm) { - permutations.push_back(p); - new_tensor_shape_dims.push_back(tensor_shape_dims[p]); - } - - TensorShape new_tensor_shape(new_tensor_shape_dims); - Tensor out_tensor = Tensor(tensor_dtype, new_tensor_shape, cpu_allocator); - ORT_RETURN_IF_ERROR(onnxruntime::utils::TensorProtoToTensor( - Env::Default(), qnn_model_wrapper.GetGraphViewer().ModelPath(), initializer, in_tensor)); - ORT_RETURN_IF_ERROR(Transpose::DoTranspose(permutations, in_tensor, out_tensor)); - onnx::TensorProto new_tensor_proto = onnxruntime::utils::TensorToTensorProto(out_tensor, "test"); - ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(new_tensor_proto, transposed_data)); - - return Status::OK(); -} - Status BaseOpBuilder::ProcessAxisAttribute(const 
QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, Qnn_Scalar_t& axis_qnn_scalar, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h index 055c0f6ccf2fa..653195d440a84 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h @@ -214,88 +214,6 @@ class BaseOpBuilder : public IOpBuilder { return it->second; } - // NCHW shape to channel last - Status NchwShapeToNhwc(const std::vector& nchw_shape, std::vector& nhwc_shape) const { - ORT_RETURN_IF_NOT(nchw_shape.size() == 4, "shape should have 4 dimension NCHW."); - nhwc_shape[0] = nchw_shape[0]; - nhwc_shape[1] = nchw_shape[2]; - nhwc_shape[2] = nchw_shape[3]; - nhwc_shape[3] = nchw_shape[1]; - - return Status::OK(); - } - - // NCHW shape to HWCN shape, required for Conv weight - Status NchwShapeToHwcn(const std::vector& nchw_shape, std::vector& hwcn_shape) const { - if (nchw_shape.size() == 4) { - hwcn_shape[0] = nchw_shape[2]; - hwcn_shape[1] = nchw_shape[3]; - hwcn_shape[2] = nchw_shape[1]; - hwcn_shape[3] = nchw_shape[0]; - } else if (nchw_shape.size() == 5) { - hwcn_shape[0] = nchw_shape[2]; - hwcn_shape[1] = nchw_shape[3]; - hwcn_shape[2] = nchw_shape[4]; - hwcn_shape[3] = nchw_shape[1]; - hwcn_shape[4] = nchw_shape[0]; - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported rank! only support 4 or 5."); - } - - return Status::OK(); - } - - // CNHW shape to HWCN shape, required for Conv weight - Status CnhwShapeToHwcn(const std::vector& cnhw_shape, std::vector& hwcn_shape) const { - if (cnhw_shape.size() == 4) { - hwcn_shape[0] = cnhw_shape[2]; - hwcn_shape[1] = cnhw_shape[3]; - hwcn_shape[2] = cnhw_shape[0]; - hwcn_shape[3] = cnhw_shape[1]; - } else if (cnhw_shape.size() == 5) { - hwcn_shape[0] = cnhw_shape[2]; - hwcn_shape[1] = cnhw_shape[3]; - hwcn_shape[2] = cnhw_shape[4]; - hwcn_shape[3] = cnhw_shape[0]; - hwcn_shape[4] = cnhw_shape[1]; - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported rank! only support 4 or 5."); - } - - return Status::OK(); - } - Status TransposeInitializer(const QnnModelWrapper& qnn_model_wrapper, - const onnx::TensorProto& initializer, - const std::vector& perm, - std::vector& transposed_data) const; - - Status TransposeFromNchwToHwcn(const QnnModelWrapper& qnn_model_wrapper, - const onnx::TensorProto& initializer, - std::vector& transposed_data, - bool is_3d = false) const { - auto& perm = is_3d ? nchw2hwcn_perm_3d : nchw2hwcn_perm; - return TransposeInitializer(qnn_model_wrapper, initializer, perm, transposed_data); - } - - Status TransposeFromCnhwToHwcn(const QnnModelWrapper& qnn_model_wrapper, - const onnx::TensorProto& initializer, - std::vector& transposed_data, - bool is_3d = false) const { - auto& perm = is_3d ? 
cnhw2hwcn_perm_3d : cnhw2hwcn_perm; - return TransposeInitializer(qnn_model_wrapper, initializer, perm, transposed_data); - } - - Status TwoDimensionTranspose(const QnnModelWrapper& qnn_model_wrapper, - std::vector& data_shape, - const onnx::TensorProto& initializer, - std::vector& transposed_data) const { - auto tmp = data_shape[0]; - data_shape[0] = data_shape[1]; - data_shape[1] = tmp; - std::vector two_dim_trans_perm{1, 0}; - return TransposeInitializer(qnn_model_wrapper, initializer, two_dim_trans_perm, transposed_data); - } - // Onnx Pads is [x1_begin, x2_begin, x1_end, x2_end], QNN requires [x1_begin, x1_end, x2_begin, x2_end] void ReArranagePads(std::vector& pads) const { auto pads_size = pads.size(); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc index d3bdee02437e4..b770ad15b7196 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc @@ -54,7 +54,7 @@ Status CastOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, std::vector unpacked_tensor; bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name); if (is_initializer_input) { - const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name); + const auto& input_tensor = qnn_model_wrapper.GetInitializerTensor(input_name); ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor)); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc index 12887f0fb72d6..f50d9f477cc1f 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc @@ -211,9 +211,9 @@ Status ConvOpBuilder::ProcessConv2D3DInputs(QnnModelWrapper& qnn_model_wrapper, // Change shape to HWCN, it could be initializer or normal input if (conv_type == OnnxConvType::kConv) { - ORT_RETURN_IF_ERROR(NchwShapeToHwcn(input_info.shape, actual_shape)); + ORT_RETURN_IF_ERROR(utils::NchwShapeToHwcn(input_info.shape, actual_shape)); } else if (conv_type == OnnxConvType::kConvTranspose) { - ORT_RETURN_IF_ERROR(CnhwShapeToHwcn(input_info.shape, actual_shape)); + ORT_RETURN_IF_ERROR(utils::CnhwShapeToHwcn(input_info.shape, actual_shape)); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unexpected convolution op type: ", node_unit.OpType().c_str()); } @@ -224,9 +224,9 @@ Status ConvOpBuilder::ProcessConv2D3DInputs(QnnModelWrapper& qnn_model_wrapper, if (input_info.is_initializer) { // Get transposed initializer bytes. 
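+      // QNN expects convolution weights in HWCN layout, so the NCHW (Conv) / CNHW (ConvTranspose) initializer
+      // data is transposed here via the shared qnn::utils helpers.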
if (conv_type == OnnxConvType::kConv) { - ORT_RETURN_IF_ERROR(TransposeFromNchwToHwcn(qnn_model_wrapper, *input_info.initializer_tensor, unpacked_tensor, is_3d)); + ORT_RETURN_IF_ERROR(utils::TransposeFromNchwToHwcn(qnn_model_wrapper, *input_info.initializer_tensor, unpacked_tensor, is_3d)); } else if (conv_type == OnnxConvType::kConvTranspose) { - ORT_RETURN_IF_ERROR(TransposeFromCnhwToHwcn(qnn_model_wrapper, *input_info.initializer_tensor, unpacked_tensor, is_3d)); + ORT_RETURN_IF_ERROR(utils::TransposeFromCnhwToHwcn(qnn_model_wrapper, *input_info.initializer_tensor, unpacked_tensor, is_3d)); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unexpected convolution op type: ", node_unit.OpType().c_str()); } @@ -413,9 +413,9 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, // Create the final shape after the weights are transposed to HWCN. if (conv_type == OnnxConvType::kConv) { - ORT_RETURN_IF_ERROR(NchwShapeToHwcn(shape_2d, final_shape)); + ORT_RETURN_IF_ERROR(utils::NchwShapeToHwcn(shape_2d, final_shape)); } else if (conv_type == OnnxConvType::kConvTranspose) { - ORT_RETURN_IF_ERROR(CnhwShapeToHwcn(shape_2d, final_shape)); + ORT_RETURN_IF_ERROR(utils::CnhwShapeToHwcn(shape_2d, final_shape)); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unexpected convolution op type: ", node_unit.OpType().c_str()); } @@ -453,9 +453,9 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, // Get transposed initializer bytes. // if (conv_type == OnnxConvType::kConv) { - ORT_RETURN_IF_ERROR(TransposeFromNchwToHwcn(qnn_model_wrapper, reshaped_initializer, unpacked_tensor)); + ORT_RETURN_IF_ERROR(utils::TransposeFromNchwToHwcn(qnn_model_wrapper, reshaped_initializer, unpacked_tensor)); } else if (conv_type == OnnxConvType::kConvTranspose) { - ORT_RETURN_IF_ERROR(TransposeFromCnhwToHwcn(qnn_model_wrapper, reshaped_initializer, unpacked_tensor)); + ORT_RETURN_IF_ERROR(utils::TransposeFromCnhwToHwcn(qnn_model_wrapper, reshaped_initializer, unpacked_tensor)); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unexpected convolution op type: ", node_unit.OpType().c_str()); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc index 64f676aaa9875..9ccff6e66ce74 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc @@ -63,7 +63,7 @@ Status ExpandOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[1].node_arg, shape), "Cannot get shape"); uint32_t shape_rank = shape[0]; std::vector unpacked_tensor; - const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name); + const auto& input_tensor = qnn_model_wrapper.GetInitializerTensor(input_name); ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor)); const int64_t* shape_data_int64 = reinterpret_cast(unpacked_tensor.data()); std::vector input_shape(shape_rank, 0); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc index eeee26c177281..e95fcda5496e6 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc @@ -110,13 +110,11 @@ Status 
GemmOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, std::vector unpacked_tensor; bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name); if (is_initializer_input) { - const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name); + const auto& input_tensor = qnn_model_wrapper.GetInitializerTensor(input_name); if (1 == input_trans_flag.at(input_i)) { ORT_RETURN_IF_ERROR(quantize_param.HandleTranspose(std::vector({1, 0}))); - ORT_RETURN_IF_ERROR(TwoDimensionTranspose(qnn_model_wrapper, - input_shape, - *input_tensor, - unpacked_tensor)); + ORT_RETURN_IF_ERROR( + utils::TwoDimensionTranspose(qnn_model_wrapper, input_shape, *input_tensor, unpacked_tensor)); } else { ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor)); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc index 5fc6d42a8a179..48b2a150fdc36 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc @@ -187,7 +187,7 @@ Status PadOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrap const auto& pads_input_name = inputs[1].node_arg.Name(); std::vector unpacked_tensor; - const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(pads_input_name); + const auto& input_tensor = qnn_model_wrapper.GetInitializerTensor(pads_input_name); ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor)); // Onnx Pads are int64, Qnn use uint32 const int64_t* tensor_data = reinterpret_cast(unpacked_tensor.data()); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc index 77bc58bd6f833..fbf39e03f63b6 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc @@ -135,7 +135,7 @@ Status ReduceOpBuilder::GetAxesSet(QnnModelWrapper& qnn_model_wrapper, const Nod } // Get axes initializer bytes. 
- const auto& axes_tensor = qnn_model_wrapper.GetInitializerTensors().at(axes_input_name); + const auto& axes_tensor = qnn_model_wrapper.GetInitializerTensor(axes_input_name); std::vector axes_bytes; ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*axes_tensor, axes_bytes)); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc index ba5ad2cf03cef..ed217ef55403b 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc @@ -88,7 +88,7 @@ Status SplitOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wr bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name); if (is_initializer_input) { std::vector unpacked_tensor; - const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name); + const auto& input_tensor = qnn_model_wrapper.GetInitializerTensor(input_name); ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor)); const int64_t* tensor_data = reinterpret_cast(unpacked_tensor.data()); size_t tensor_byte_size = unpacked_tensor.size(); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc index 851ca84dce075..bd66d030d492d 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc @@ -69,7 +69,7 @@ Status TileOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra const auto& repeats_input_name = node_unit.Inputs()[1].node_arg.Name(); std::vector unpacked_tensor; - const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(repeats_input_name); + const auto& input_tensor = qnn_model_wrapper.GetInitializerTensor(repeats_input_name); ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor)); // Onnx repeats are int64, Qnn use uint32 const int64_t* tensor_data = reinterpret_cast(unpacked_tensor.data()); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc index d22c0811682d0..4482658cd94e7 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc @@ -94,7 +94,7 @@ Status TopKOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name); if (is_initializer_input) { std::vector unpacked_tensor; - const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name); + const auto& input_tensor = qnn_model_wrapper.GetInitializerTensor(input_name); ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor)); const int64_t* tensor_data = reinterpret_cast(unpacked_tensor.data()); k = static_cast(*tensor_data); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index 2c7f3c8b22ddd..0cc5565a57a3b 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -489,7 +489,7 @@ Status QnnModelWrapper::GetTensorInfo(const NodeUnitIODef& input, TensorInfo& te // Fill in initializer info. 
tensor_info.is_initializer = IsInitializerInput(name); if (tensor_info.is_initializer) { - tensor_info.initializer_tensor = GetInitializerTensors().at(name); + tensor_info.initializer_tensor = GetInitializerTensor(name); } return Status::OK(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index f3e52050e79e0..7106e5bca9e4a 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -111,6 +111,12 @@ class QnnModelWrapper { const InitializedTensorSet& GetInitializerTensors() const { return graph_viewer_.GetAllInitializedTensors(); } + const ONNX_NAMESPACE::TensorProto* GetInitializerTensor(const std::string& tensor_name) const { + const ONNX_NAMESPACE::TensorProto* initializer = nullptr; + graph_viewer_.GetInitializedTensor(tensor_name, initializer); + return initializer; + } + bool IsInitializerInput(std::string input_name) const { return initializer_lookup_.find(input_name) != initializer_lookup_.end(); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc index 9fb9e815321c0..cf549148550f1 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -18,6 +18,7 @@ #include "core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h" +#include "core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h" namespace onnxruntime { namespace qnn { @@ -92,6 +93,7 @@ static std::unique_ptr TryQnnFusions( {"HardSigmoid", HardSigmoidMulFusion::TryFusion}, {"Conv", ConvActivationFusion::TryFusion}, {"ConvTranspose", ConvActivationFusion::TryFusion}, + {"Gemm", ReshapeGemmFusion::TryFusion}, }; // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes). diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.cc new file mode 100644 index 0000000000000..bdf06d8fa4a1e --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.cc @@ -0,0 +1,191 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
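+//
+// Fuses a 2D-flattening Reshape followed by a Gemm into a single QNN FullyConnected node.
+// See the ReshapeGemmFusion class in reshape_gemm_fusion.h for the pattern constraints.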
+ +#include "core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h" + +#include +#include +#include +#include +#include +#include + +#include "core/graph/graph_utils.h" +#include "core/framework/node_unit.h" +#include "core/framework/tensorprotoutils.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace qnn { + +namespace { + +const NodeUnit* GetReshapeNodeUnit( + const GraphViewer& graph_viewer, const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const Node& gemm_node) { + if (gemm_node.OpType() != "Gemm") { + return nullptr; + } + + for (auto edge = gemm_node.InputEdgesBegin(); edge != gemm_node.InputEdgesEnd(); edge++) { + if (edge->GetDstArgIndex() == 0) { + const Node& reshape_node = edge->GetNode(); + if (reshape_node.OpType() == "Reshape" && !graph_viewer.NodeProducesGraphOutput(reshape_node) && + reshape_node.GetOutputEdgesCount() == 1) { + const auto it = node_to_node_unit.find(&reshape_node); + if (it != node_to_node_unit.end()) { + const NodeUnit* reshape_node_unit = it->second; + if (reshape_node_unit && node_unit_to_qnn_node_group.count(reshape_node_unit) == 0 && + reshape_node_unit->UnitType() == NodeUnit::Type::SingleNode) { + return reshape_node_unit; + } + } + } + } + } + + return nullptr; +} + +// Reshape from [x0, x1, ..., xn, k] to [x0 * x1 * ... * xn, k]. +bool CheckShape(const Node& reshape_node) { + auto input_shape_proto = reshape_node.InputDefs()[0]->Shape(); + auto output_shape_proto = reshape_node.OutputDefs()[0]->Shape(); + if (!input_shape_proto || !output_shape_proto) { + return false; + } + + auto input_shape = onnxruntime::utils::GetTensorShapeFromTensorShapeProto(*input_shape_proto); + auto output_shape = onnxruntime::utils::GetTensorShapeFromTensorShapeProto(*output_shape_proto); + auto input_rank = input_shape.NumDimensions(); + auto output_rank = output_shape.NumDimensions(); + return input_shape.Size() != -1 && output_shape.Size() != -1 && output_rank == 2 && + input_shape.SizeToDimension(input_rank - 1) == output_shape[0] && + input_shape[input_rank - 1] == output_shape[1]; +} + +Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& reshape_node_unit, + const NodeUnit& gemm_node_unit, bool validate) { + assert(reshape_node_unit.OpType() == "Reshape" && gemm_node_unit.OpType() == "Gemm"); + const auto& node_name = utils::GetNodeName(gemm_node_unit); + const NodeUnitIODef& input_def = reshape_node_unit.Inputs()[0]; + const NodeUnitIODef& weight_def = gemm_node_unit.Inputs()[1]; + const NodeUnitIODef* bias_def_ptr = nullptr; + bool has_bias = gemm_node_unit.Inputs().size() == 3 && gemm_node_unit.Inputs()[2].node_arg.Exists(); + if (has_bias) { + bias_def_ptr = &gemm_node_unit.Inputs()[2]; + } + const NodeUnitIODef& output_def = gemm_node_unit.Outputs()[0]; + + QnnTensorWrapper input_tensor; + QnnTensorWrapper bias_tensor; + QnnTensorWrapper output_tensor; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor)); + std::vector weight_shape; + std::vector unpacked_tensor; + std::string weight_tensor_name = weight_def.node_arg.Name(); + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(weight_def.node_arg, weight_shape), "Failed to get weight shape"); + Qnn_TensorType_t tensor_type = 
qnn_model_wrapper.GetTensorType(weight_tensor_name); + Qnn_DataType_t data_type = QNN_DATATYPE_FLOAT_32; + ORT_RETURN_IF_ERROR(utils::GetQnnDataType(false, weight_def.node_arg.TypeAsProto(), data_type)); + const auto& weight_tensor_proto = qnn_model_wrapper.GetInitializerTensor(weight_tensor_name); + ORT_RETURN_IF_ERROR( + utils::TwoDimensionTranspose(qnn_model_wrapper, weight_shape, *weight_tensor_proto, unpacked_tensor)); + QnnTensorWrapper weight_tensor(weight_tensor_name, tensor_type, data_type, QnnQuantParamsWrapper(), + std::move(weight_shape), std::move(unpacked_tensor)); + if (has_bias) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(*bias_def_ptr, bias_tensor)); + } + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor)); + + if (validate) { + std::vector input_tensors = {input_tensor.GetQnnTensor(), weight_tensor.GetQnnTensor()}; + if (has_bias) { + input_tensors.emplace_back(bias_tensor.GetQnnTensor()); + } + ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_FULLY_CONNECTED, std::move(input_tensors), + {output_tensor.GetQnnTensor()}, {})); + } else { + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(weight_tensor)), "Failed to add weight"); + if (has_bias) { + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(bias_tensor)), "Failed to add bias"); + } + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); + std::vector input_names = {input_def.node_arg.Name(), weight_tensor_name}; + if (has_bias) { + input_names.emplace_back(bias_def_ptr->node_arg.Name()); + } + ORT_RETURN_IF_NOT( + qnn_model_wrapper.CreateQnnNode(node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_FULLY_CONNECTED, + std::move(input_names), {output_def.node_arg.Name()}, {}, validate), + "Failed to add fused Gemm node."); + } + + return Status::OK(); +} + +} // namespace + +std::unique_ptr ReshapeGemmFusion::TryFusion( + QnnModelWrapper& qnn_model_wrapper, const NodeUnit& gemm_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& /*logger*/) { + if (gemm_node_unit.OpType() != "Gemm" || gemm_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + return nullptr; + } + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const Node& gemm_node = gemm_node_unit.GetNode(); + NodeAttrHelper helper(gemm_node); + auto transA = helper.Get("transA", static_cast(0)); + auto transB = helper.Get("transB", static_cast(0)); + const auto& weight_input = gemm_node_unit.Inputs()[1]; + // The pattern is from MatMul->Add, so the transA and transB should be false, and weight should be initializer. + // Currently we don't handle quantized weight. 
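+  // The weight must also be a graph initializer because CreateOrValidateOnQnn transposes its data from [k, n] to
+  // [n, k] at graph build time for QNN's FullyConnected op.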
+ if (transA != 0 || transB != 0 || !qnn_model_wrapper.IsInitializerInput(weight_input.node_arg.Name()) || + weight_input.quant_param.has_value()) { + return nullptr; + } + + const NodeUnit* reshape_node_unit = + GetReshapeNodeUnit(graph_viewer, node_to_node_unit, node_unit_to_qnn_node_group, gemm_node); + if (!reshape_node_unit) { + return nullptr; + } + + if (!CheckShape(reshape_node_unit->GetNode())) { + return nullptr; + } + + return std::make_unique(*reshape_node_unit, gemm_node_unit); +} + +ReshapeGemmFusion::ReshapeGemmFusion(const NodeUnit& reshape_node_unit, const NodeUnit& gemm_node_unit) + : node_units_{} { + node_units_[0] = &reshape_node_unit; + node_units_[1] = &gemm_node_unit; +} + +Status ReshapeGemmFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& /*logger*/) const { + return CreateOrValidateOnQnn(qmw, *node_units_[0], *node_units_[1], true); +} + +Status ReshapeGemmFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& /*logger*/) const { + return CreateOrValidateOnQnn(qmw, *node_units_[0], *node_units_[1], false); +} + +gsl::span ReshapeGemmFusion::GetNodeUnits() const { + return gsl::make_span(node_units_.data(), 2); +} + +const NodeUnit* ReshapeGemmFusion::GetTargetNodeUnit() const { return node_units_[1]; } + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h new file mode 100644 index 0000000000000..a07e8de5a8e2b --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; + +/// +/// Represents a fusion of a Reshape->Gemm sequence to a single Gemm node. +/// Ideally Reshape->Gemm->Reshape should be fused to a single Gemm node with keep_dims set to True, +/// but on some devices the OpConfig validation will fail when keep_dims to True (it says expected value is 0), +/// so we still need to keep the 2nd Reshape node. 
+/// +class ReshapeGemmFusion : public IQnnNodeGroup { + public: + ReshapeGemmFusion(const NodeUnit& reshape_node_unit, const NodeUnit& gemm_node_unit); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ReshapeGemmFusion); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + gsl::span GetNodeUnits() const override; + const NodeUnit* GetTargetNodeUnit() const override; + std::string_view Type() const override { return "ReshapeGemmFusion"; } + + static std::unique_ptr TryFusion( + QnnModelWrapper& qnn_model_wrapper, const NodeUnit& gemm_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + + private: + std::array node_units_; +}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc index 8d2cb5bdb6da0..67026cf44fa69 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/providers/qnn/builder/qnn_utils.h" + #include #include #include @@ -9,7 +11,8 @@ #include "core/common/common.h" #include "core/framework/data_types.h" -#include "qnn_utils.h" +#include "core/framework/tensorprotoutils.h" +#include "core/providers/cpu/tensor/transpose.h" #include "core/providers/qnn/builder/qnn_def.h" namespace onnxruntime { @@ -570,6 +573,109 @@ Status Quantize(const double double_value, return Status::OK(); } +// NCHW shape to channel last +Status NchwShapeToNhwc(const std::vector& nchw_shape, std::vector& nhwc_shape) { + ORT_RETURN_IF_NOT(nchw_shape.size() == 4, "shape should have 4 dimension NCHW."); + nhwc_shape[0] = nchw_shape[0]; + nhwc_shape[1] = nchw_shape[2]; + nhwc_shape[2] = nchw_shape[3]; + nhwc_shape[3] = nchw_shape[1]; + + return Status::OK(); +} + +// NCHW shape to HWCN shape, required for Conv weight +Status NchwShapeToHwcn(const std::vector& nchw_shape, std::vector& hwcn_shape) { + if (nchw_shape.size() == 4) { + hwcn_shape[0] = nchw_shape[2]; + hwcn_shape[1] = nchw_shape[3]; + hwcn_shape[2] = nchw_shape[1]; + hwcn_shape[3] = nchw_shape[0]; + } else if (nchw_shape.size() == 5) { + hwcn_shape[0] = nchw_shape[2]; + hwcn_shape[1] = nchw_shape[3]; + hwcn_shape[2] = nchw_shape[4]; + hwcn_shape[3] = nchw_shape[1]; + hwcn_shape[4] = nchw_shape[0]; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported rank! only support 4 or 5."); + } + + return Status::OK(); +} + +// CNHW shape to HWCN shape, required for Conv weight +Status CnhwShapeToHwcn(const std::vector& cnhw_shape, std::vector& hwcn_shape) { + if (cnhw_shape.size() == 4) { + hwcn_shape[0] = cnhw_shape[2]; + hwcn_shape[1] = cnhw_shape[3]; + hwcn_shape[2] = cnhw_shape[0]; + hwcn_shape[3] = cnhw_shape[1]; + } else if (cnhw_shape.size() == 5) { + hwcn_shape[0] = cnhw_shape[2]; + hwcn_shape[1] = cnhw_shape[3]; + hwcn_shape[2] = cnhw_shape[4]; + hwcn_shape[3] = cnhw_shape[0]; + hwcn_shape[4] = cnhw_shape[1]; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported rank! 
only support 4 or 5."); + } + + return Status::OK(); +} + +namespace { +Status TransposeInitializer(const QnnModelWrapper& qnn_model_wrapper, const onnx::TensorProto& initializer, + const std::vector& perm, std::vector& transposed_data) { + const DataTypeImpl* tensor_dtype = DataTypeImpl::TensorTypeFromONNXEnum(initializer.data_type())->GetElementType(); + const auto tensor_shape_dims = onnxruntime::utils::GetTensorShapeFromTensorProto(initializer); + TensorShape tensor_shape{tensor_shape_dims}; + AllocatorPtr cpu_allocator = std::make_shared(); + Tensor in_tensor = Tensor(tensor_dtype, tensor_shape, cpu_allocator); + + auto rank = perm.size(); + std::vector new_tensor_shape_dims; + std::vector permutations; + new_tensor_shape_dims.reserve(rank); + permutations.reserve(rank); + for (int64_t p : perm) { + permutations.push_back(p); + new_tensor_shape_dims.push_back(tensor_shape_dims[p]); + } + + TensorShape new_tensor_shape(new_tensor_shape_dims); + Tensor out_tensor = Tensor(tensor_dtype, new_tensor_shape, cpu_allocator); + ORT_RETURN_IF_ERROR(onnxruntime::utils::TensorProtoToTensor( + Env::Default(), qnn_model_wrapper.GetGraphViewer().ModelPath(), initializer, in_tensor)); + ORT_RETURN_IF_ERROR(Transpose::DoTranspose(permutations, in_tensor, out_tensor)); + onnx::TensorProto new_tensor_proto = onnxruntime::utils::TensorToTensorProto(out_tensor, "test"); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(new_tensor_proto, transposed_data)); + + return Status::OK(); +} +} // namespace + +Status TransposeFromNchwToHwcn(const QnnModelWrapper& qnn_model_wrapper, const onnx::TensorProto& initializer, + std::vector& transposed_data, bool is_3d) { + auto& perm = is_3d ? nchw2hwcn_perm_3d : nchw2hwcn_perm; + return TransposeInitializer(qnn_model_wrapper, initializer, perm, transposed_data); +} + +Status TransposeFromCnhwToHwcn(const QnnModelWrapper& qnn_model_wrapper, const onnx::TensorProto& initializer, + std::vector& transposed_data, bool is_3d) { + auto& perm = is_3d ? 
cnhw2hwcn_perm_3d : cnhw2hwcn_perm; + return TransposeInitializer(qnn_model_wrapper, initializer, perm, transposed_data); +} + +Status TwoDimensionTranspose(const QnnModelWrapper& qnn_model_wrapper, std::vector& data_shape, + const onnx::TensorProto& initializer, std::vector& transposed_data) { + auto tmp = data_shape[0]; + data_shape[0] = data_shape[1]; + data_shape[1] = tmp; + std::vector two_dim_trans_perm{1, 0}; + return TransposeInitializer(qnn_model_wrapper, initializer, two_dim_trans_perm, transposed_data); +} + } // namespace utils } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.h b/onnxruntime/core/providers/qnn/builder/qnn_utils.h index aa4a27460563f..76743f4b8c69b 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.h @@ -12,6 +12,7 @@ #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/node_unit.h" #include "core/util/qmath.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" namespace onnxruntime { namespace qnn { @@ -104,6 +105,24 @@ Status Quantize(const double double_value, const Qnn_DataType_t qnn_data_type, int& quant_value); +// NCHW shape to channel last +Status NchwShapeToNhwc(const std::vector& nchw_shape, std::vector& nhwc_shape); + +// NCHW shape to HWCN shape, required for Conv weight +Status NchwShapeToHwcn(const std::vector& nchw_shape, std::vector& hwcn_shape); + +// CNHW shape to HWCN shape, required for Conv weight +Status CnhwShapeToHwcn(const std::vector& cnhw_shape, std::vector& hwcn_shape); + +Status TransposeFromNchwToHwcn(const QnnModelWrapper& qnn_model_wrapper, const onnx::TensorProto& initializer, + std::vector& transposed_data, bool is_3d = false); + +Status TransposeFromCnhwToHwcn(const QnnModelWrapper& qnn_model_wrapper, const onnx::TensorProto& initializer, + std::vector& transposed_data, bool is_3d = false); + +Status TwoDimensionTranspose(const QnnModelWrapper& qnn_model_wrapper, std::vector& data_shape, + const onnx::TensorProto& initializer, std::vector& transposed_data); + } // namespace utils } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 67d60ea3a4ff6..4ffe6bcf5ef15 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -2443,12 +2443,8 @@ TEST_F(GraphTransformationTests, MatMulAddFusion_three_input) { ASSERT_TRUE(op_to_count["Gemm"] == 1); } -// Matmul+Add with shape [k]*[k,N]+[N], won't do the fusion -// We can do the fusion by changing shape to [1,k]*[k,N]+[1,N], then add a reshape [1,N]=>[N] -// This will bring extra cost. 
And there's only very limited gain to fuse Matmul+Add to Gemm -// Since the basic implementation is almost same -TEST_F(GraphTransformationTests, MatMulAddFusion_negitive_case) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "matmul_add_fusion/3Input/neg_model.onnx"; +TEST_F(GraphTransformationTests, MatMulAddFusion_three_input_with_1d) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "matmul_add_fusion/3Input/model_1d.onnx"; std::shared_ptr p_model; ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); @@ -2459,9 +2455,22 @@ TEST_F(GraphTransformationTests, MatMulAddFusion_negitive_case) { ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["MatMul"] == 1); - ASSERT_TRUE(op_to_count["Add"] == 1); - ASSERT_TRUE(op_to_count["Gemm"] == 0); + ASSERT_TRUE(op_to_count["MatMul"] == 0); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["Gemm"] == 1); + ASSERT_TRUE(op_to_count["Reshape"] == 2); + for (const Node& node : graph.Nodes()) { + if (node.OpType() == "Reshape") { + auto shape_proto = node.OutputDefs()[0]->Shape(); + ASSERT_TRUE(shape_proto != nullptr); + auto shape = utils::GetTensorShapeFromTensorShapeProto(*shape_proto); + if (node.Name().find("gemm_input") != std::string::npos) { + ASSERT_TRUE(shape.NumDimensions() == 2 && shape[0] == 1 && shape[1] == 4); + } else { + ASSERT_TRUE(shape.NumDimensions() == 1 && shape[0] == 3); + } + } + } } // Matmul+Add with shape [M,k]*[k,N]+[1,4], won't do the fusion @@ -2500,6 +2509,50 @@ TEST_F(GraphTransformationTests, MatMulAddFusion_MissingShape) { ASSERT_EQ(op_to_count["Gemm"], 0); } +TEST_F(GraphTransformationTests, MatMulAddFusion_NeedReshape_3D) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({{8, 16, 32}}); + auto* weight_arg = builder.MakeInput({{32, 768}}); + auto* bias_arg = builder.MakeInput({{768}}); + auto* matmul_out = builder.MakeIntermediate(); + auto* output_arg = builder.MakeOutput(); + builder.AddNode("MatMul", {input_arg, weight_arg}, {matmul_out}); + builder.AddNode("Add", {matmul_out, bias_arg}, {output_arg}); + }; + + auto pre_graph_checker = [](Graph& graph) { + std::map op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["MatMul"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Add"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) { + std::map op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["MatMul"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Add"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Gemm"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 2); + for (const Node& node : graph.Nodes()) { + if (node.OpType() == "Reshape") { + auto shape_proto = node.OutputDefs()[0]->Shape(); + TEST_RETURN_IF_NOT(shape_proto != nullptr); + auto shape = utils::GetTensorShapeFromTensorShapeProto(*shape_proto); + if (node.Name().find("gemm_input") != std::string::npos) { + TEST_RETURN_IF_NOT(shape.NumDimensions() == 2 && shape[0] == 8 * 16 && shape[1] == 32); + } else { + TEST_RETURN_IF_NOT(shape.NumDimensions() == 3 && shape[0] == 8 && shape[1] == 16 && shape[2] == 768); + } + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); +} + #ifndef 
DISABLE_CONTRIB_OPS TEST_F(GraphTransformationTests, Gemm_Relu_three_input) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "matmul_add_fusion/3Input/gemm_relu.onnx"; @@ -4080,6 +4133,41 @@ TEST_F(GraphTransformationTests, ReshapeFusionDistilBertTest) { } } +TEST_F(GraphTransformationTests, ReshapeFusion_Contiguous_Reshape) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({{8, 16, 32}}); + auto* shape_initializer_1 = builder.MakeInitializer({4}, {2, 4, 16, 32}); + auto* shape_initializer_2 = builder.MakeInitializer({4}, {2, 64, 32}); + auto* axes_initializer = builder.MakeInitializer({1}, {1}); + auto* reshape_out_1 = builder.MakeIntermediate(); + auto* reshape_out_2 = builder.MakeIntermediate(); + auto* unsqueeze_out = builder.MakeIntermediate(); + auto* output_arg = builder.MakeOutput(); + builder.AddNode("Reshape", {input_arg, shape_initializer_1}, {reshape_out_1}); + builder.AddNode("Reshape", {reshape_out_1, shape_initializer_2}, {reshape_out_2}); + builder.AddNode("Unsqueeze", {reshape_out_2, axes_initializer}, {unsqueeze_out}); + builder.AddNode("Identity", {unsqueeze_out}, {output_arg}); + }; + + auto pre_graph_checker = [](Graph& graph) { + std::map op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 2); + TEST_RETURN_IF_NOT(op_to_count["Unsqueeze"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) { + std::map op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Unsqueeze"] == 0); + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); +} + // Test eliminating redundant Concat-Slice pattern. 
TEST_F(GraphTransformationTests, ConcatSliceEliminationTest) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "concat_slice_basic_test.onnx"; diff --git a/onnxruntime/test/providers/qnn/gemm_op_test.cc b/onnxruntime/test/providers/qnn/gemm_op_test.cc index 33c868694c9c0..6ebc02f8ad803 100644 --- a/onnxruntime/test/providers/qnn/gemm_op_test.cc +++ b/onnxruntime/test/providers/qnn/gemm_op_test.cc @@ -177,6 +177,47 @@ TEST_F(QnnCPUBackendTests, Gemm_Broadcast_Bias_DynamicA_StaticB_StaticC) { ExpectedEPNodeAssignment::All); } +namespace { +GetTestModelFn BuildReshapeGemmTestCase(const TestInputDef& input, const TestInputDef& shape, + const TestInputDef& weight, const TestInputDef& bias) { + return [&](ModelTestBuilder& builder) { + std::vector reshape_inputs = {MakeTestInput(builder, input), + MakeTestInput(builder, shape)}; + auto* reshape_output = builder.MakeIntermediate(); + builder.AddNode("Reshape", reshape_inputs, {reshape_output}); + NodeArg* output = builder.MakeOutput(); + std::vector gemm_inputs = {reshape_output, MakeTestInput(builder, weight), + MakeTestInput(builder, bias)}; + builder.AddNode("Gemm", gemm_inputs, {output}); + }; +} + +void RunCPUReshapeGemmTest(const TestInputDef& input, const TestInputDef& shape, + const TestInputDef& weight, const TestInputDef& bias, + ExpectedEPNodeAssignment expected_ep_assignment, float fp32_abs_err = 1e-5f) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + auto build_fn = BuildReshapeGemmTestCase(input, shape, weight, bias); + RunQnnModelTest(build_fn, provider_options, 18, expected_ep_assignment, fp32_abs_err); +} + +} // namespace + +TEST_F(QnnCPUBackendTests, ReshapeGemmFusion) { + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}; + std::vector shape_data = {4, 2}; + std::vector weight_data(6, 1.0f); + std::vector bias_data = {1.0f, 2.0f, 3.0f}; + RunCPUReshapeGemmTest(TestInputDef({2, 2, 2}, false, input_data), TestInputDef({2}, true, shape_data), + TestInputDef({2, 3}, true, weight_data), TestInputDef({3}, true, bias_data), + ExpectedEPNodeAssignment::All); +} + #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) // // HTP tests: diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index e8282dbad9f72..ebb6b57f82d07 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -881,10 +881,10 @@ TEST_F(QnnHTPBackendTests, QnnContextPriorityHigh) { "high"); // qnn_context_priority } -// Create a model with Case + Add (quantized) -// cast_input -> Cast -> Q -> DQ \ -// Add -> Q -> DQ -> output -// input2 -> Q -> DQ / +// Create a model with Cast + Add (quantized) +// cast_input -> Cast -> Q -> DQ ---- +// | +// input2 -> Q -> DQ -> Add -> Q -> DQ -> output static GetTestModelFn BuildCastAddTestCase() { return [](ModelTestBuilder& builder) { // Creat Cast node int32 -> float32 diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index a3f0ed55b83f2..d3313b4c9b1c7 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -26,9 +26,9 @@ namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) // Create a model with FusedMatMul + Add (quantized) -// input1 -> Add -> Q -> DQ \ -// 
FusedMatMul -> Q -> DQ -> output -// input2 -> Q -> DQ / +// input1 -> Add -> Q -> DQ ---- +// | +// input2 -> Q -> DQ -> FusedMatMul -> Q -> DQ -> output static GetTestModelFn BuildGraphWithQAndNonQ(bool single_ep_node = true) { return [single_ep_node](ModelTestBuilder& builder) { // Creat non-quantized FusedMatMul node1 @@ -162,10 +162,10 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport2) { QnnContextBinaryMultiPartitionTestBody(single_ep_node); } -// Create a model with Case + Add (quantized) -// cast_input -> Cast -> Q -> DQ \ -// Add -> Q -> DQ -> output -// input2 -> Q -> DQ / +// Create a model with Cast + Add (quantized) +// cast_input -> Cast -> Q -> DQ ---- +// | +// input2 -> Q -> DQ -> Add -> Q -> DQ -> output static GetTestModelFn BuildCastAddTestCase() { return [](ModelTestBuilder& builder) { // Creat Cast node int32 -> float32 @@ -194,9 +194,9 @@ static GetTestModelFn BuildCastAddTestCase() { } // Create a model with Add (quantized) -// input1 -> Q -> DQ \ -// Add -> Q -> DQ -> output -// input2 -> Q -> DQ / +// input1 -> Q -> DQ ---- +// | +// input2 -> Q -> DQ -> Add -> Q -> DQ -> output static GetTestModelFn BuildAddTestCase() { return [](ModelTestBuilder& builder) { std::vector data = {0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f}; diff --git a/onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/neg_model.onnx b/onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/model_1d.onnx similarity index 100% rename from onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/neg_model.onnx rename to onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/model_1d.onnx