diff --git a/onnxruntime/core/framework/node_unit.cc b/onnxruntime/core/framework/node_unit.cc index 54964b0275fc8..a491edb9699b3 100644 --- a/onnxruntime/core/framework/node_unit.cc +++ b/onnxruntime/core/framework/node_unit.cc @@ -272,41 +272,18 @@ NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_g } } -NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group, - const Node& output_activation_node) - : dq_nodes_{GetQDQIONodes(graph_viewer, node_group, true /* is_input */)}, - target_node_(*graph_viewer.GetNode(node_group.target_node)), - q_nodes_{GetQDQIONodes(graph_viewer, node_group, false /* is_input */)}, - type_(Type::QDQGroup), - inputs_{GetQDQIODefs(target_node_, node_group, true /* is_input */)}, - outputs_{GetQDQIODefs(output_activation_node, node_group, false /* is_input */)} { - input_edge_count_ = std::accumulate(dq_nodes_.cbegin(), dq_nodes_.cend(), size_t(0), - [](size_t acc, const Node* node) { return acc + node->GetInputEdgesCount(); }); - - // add edges for inputs that are not from DQ nodes. there is one edge to each DQ node. - // other inputs could come from initializers or graph inputs (no edges) or other nodes (edge). - input_edge_count_ += target_node_.GetInputEdgesCount() - dq_nodes_.size(); - - // create output edges. each target node output either goes to Q node/s or non-Q node/s. - // ValidateNodeGroupQDQNodes ensures this. - auto cur_edge = output_activation_node.OutputEdgesBegin(); - auto end_edge = output_activation_node.OutputEdgesEnd(); - for (; cur_edge != end_edge; ++cur_edge) { - const Node& node = cur_edge->GetNode(); - - // if node is in q_nodes we hide the Q node. - if (std::find(q_nodes_.cbegin(), q_nodes_.cend(), &node) != q_nodes_.cend()) { - auto src_idx = cur_edge->GetSrcArgIndex(); - auto q_cur_edge = node.OutputEdgesBegin(); - auto q_end_edge = node.OutputEdgesEnd(); - for (; q_cur_edge != q_end_edge; ++q_cur_edge) { - output_edges_.insert(Node::EdgeEnd{q_cur_edge->GetNode(), src_idx, q_cur_edge->GetDstArgIndex()}); - } - } else { - // non-Q node, or Q node that isn't in the QDQ node group (unexpected but may be possible). add as-is. - output_edges_.insert(*cur_edge); - } - } +NodeUnit::NodeUnit(std::vector dq_nodes, const Node& target_node, + std::vector q_nodes, Type type, + std::vector inputs, std::vector outputs, + size_t input_edge_count, Node::EdgeSet output_edges) + : dq_nodes_(std::move(dq_nodes)), + target_node_(target_node), + q_nodes_(std::move(q_nodes)), + type_(type), + inputs_(std::move(inputs)), + outputs_(std::move(outputs)), + input_edge_count_(input_edge_count), + output_edges_(std::move(output_edges)) { } const std::string& NodeUnit::Domain() const noexcept { return target_node_.Domain(); } diff --git a/onnxruntime/core/framework/node_unit.h b/onnxruntime/core/framework/node_unit.h index 494d7bd849b4b..060653170fde0 100644 --- a/onnxruntime/core/framework/node_unit.h +++ b/onnxruntime/core/framework/node_unit.h @@ -68,8 +68,10 @@ class NodeUnit { public: explicit NodeUnit(const Node& node); explicit NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group); - explicit NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group, - const Node& output_activation_node); + NodeUnit(std::vector dq_nodes, const Node& target_node, + std::vector q_nodes, Type type, + std::vector inputs, std::vector outputs, + size_t input_edge_count, Node::EdgeSet output_edges); Type UnitType() const noexcept { return type_; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc index b62c5f21f82ba..e5891baf4ac50 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc @@ -378,19 +378,60 @@ static bool IsValidQDQConv(gsl::span dq_node_units, Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, gsl::span dq_node_units, const NodeUnit* conv_node_unit, - const NodeUnit* activation_node_unit, const NodeUnit* q_node_unit, const logging::Logger& logger, bool validate) { - QDQ::NodeGroup custom_node_group; - custom_node_group.dq_nodes.reserve(dq_node_units.size()); - custom_node_group.q_nodes = std::vector{q_node_unit->Index()}; - custom_node_group.target_node = conv_node_unit->Index(); - auto get_node_idx = [](const NodeUnit* n) { return n->Index(); }; - std::transform(dq_node_units.begin(), dq_node_units.end(), std::back_inserter(custom_node_group.dq_nodes), - get_node_idx); - - NodeUnit custom_node_unit(qnn_model_wrapper.GetGraphViewer(), custom_node_group, activation_node_unit->GetNode()); + std::vector dq_nodes; + dq_nodes.reserve(dq_node_units.size()); + for (const NodeUnit* dq_node_unit : dq_node_units) { + dq_nodes.push_back(&dq_node_unit->GetNode()); + } + std::vector q_nodes = {&q_node_unit->GetNode()}; + const Node& target_node = conv_node_unit->GetNode(); + + // Populate NodeUnit inputs + std::vector inputs; + inputs.reserve(dq_node_units.size()); + for (const Node* dq_node : dq_nodes) { + const auto dq_inputs = dq_node->InputDefs(); + const auto& dq_attrs = dq_node->GetAttributes(); + + std::optional axis; + if (auto entry = dq_attrs.find("axis"); entry != dq_attrs.end()) { + axis = entry->second.i(); + } + + // quantization scale and zp are always the input[1, 2] + NodeUnitIODef::QuantParam quant_param{*dq_inputs[1], dq_inputs.size() == 3 ? dq_inputs[2] : nullptr, axis}; + inputs.push_back(NodeUnitIODef{*dq_inputs[0], quant_param}); + } + + // Populate NodeUnit outputs and output edges + std::vector outputs; + Node::EdgeSet output_edges; + for (const Node* q_node : q_nodes) { + const auto q_inputs = q_node->InputDefs(); + const auto& q_attrs = q_node->GetAttributes(); + const auto q_outputs = q_node->OutputDefs(); + + std::optional axis; + if (auto entry = q_attrs.find("axis"); entry != q_attrs.end()) { + axis = entry->second.i(); + } + + // quantization scale and zp are always the input[1, 2] + NodeUnitIODef::QuantParam quant_param{*q_inputs[1], q_inputs.size() == 3 ? q_inputs[2] : nullptr, axis}; + outputs.push_back(NodeUnitIODef{*q_outputs[0], quant_param}); + + auto q_cur_edge = q_node->OutputEdgesBegin(); + auto q_end_edge = q_node->OutputEdgesEnd(); + for (; q_cur_edge != q_end_edge; ++q_cur_edge) { + output_edges.insert(Node::EdgeEnd{q_cur_edge->GetNode(), 0, q_cur_edge->GetDstArgIndex()}); + } + } + + NodeUnit custom_node_unit(dq_nodes, target_node, q_nodes, NodeUnit::Type::QDQGroup, + inputs, outputs, dq_nodes.size(), output_edges); const auto* conv_op_builder = qnn::GetOpBuilder(custom_node_unit.OpType()); if (conv_op_builder == nullptr) { return Status::OK(); @@ -463,7 +504,6 @@ std::optional TryConvActivationFusion(QnnModelWrapper& qnn_model_w if (Status status = QnnConvActivationFusionAdd(tmp_model_wrapper, dq_node_units, &conv_node_unit, - activation_node_unit, q_node_unit, logger, /*validate*/ true); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h index 3f5bdfc3078dc..b019dbb9205d6 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h @@ -18,7 +18,6 @@ namespace qnn { Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, gsl::span dq_node_units, const NodeUnit* conv_node_unit, - const NodeUnit* activation_node_unit, const NodeUnit* q_node_unit, const logging::Logger& logger, bool validate = false); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc b/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc index 04b727f534424..033ac3ce4a555 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc @@ -155,7 +155,6 @@ Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& lo Status status = QnnConvActivationFusionAdd(qmw, dq_node_units, conv_node_unit, - activation_node_unit, q_node_unit, logger, /*validate*/ true); @@ -225,7 +224,6 @@ Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logg const bool has_bias_dq = num_node_units == 6; std::vector dq_node_units = {node_units_[0], node_units_[1]}; const NodeUnit* conv_node_unit = node_units_[num_node_units - 3]; - const NodeUnit* activation_node_unit = node_units_[num_node_units - 2]; const NodeUnit* q_node_unit = node_units_[num_node_units - 1]; if (has_bias_dq) { @@ -234,7 +232,6 @@ Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logg return QnnConvActivationFusionAdd(qmw, dq_node_units, conv_node_unit, - activation_node_unit, q_node_unit, logger, /*validate*/ false); diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 37eeac5101feb..5146a1cc14865 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -953,7 +953,7 @@ TEST_F(QnnHTPBackendTests, TestOD) { #if 1 const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "od_current_tf2onnx.onnx"; - //so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); #else const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "unet.preprocessed.quant.onnx_ctx.onnx"; #endif