From 6f042c86e1b367f04ec8fe698c37f8266086d006 Mon Sep 17 00:00:00 2001
From: adrianlizarraga
Date: Sat, 27 Jul 2024 00:46:22 -0700
Subject: [PATCH] Don't always return Status

---
 .../qnn/builder/qnn_conv_activation_fusion.cc | 153 ++++++++++--------
 .../qnn/builder/qnn_conv_activation_fusion.h  |  12 +-
 .../core/providers/qnn/builder/qnn_fusions.cc | 127 ++++++++-------
 3 files changed, 166 insertions(+), 126 deletions(-)

diff --git a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc
index 0aba738627e57..b62c5f21f82ba 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc
@@ -29,6 +29,9 @@ static const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer,
   // Child must be of a valid type.
   const Node& child_node = parent_node.OutputEdgesBegin()->GetNode();
+  if (graph_viewer.GetNode(child_node.Index()) == nullptr) {
+    return nullptr;  // Node is not in this GraphViewer
+  }
   const std::string& child_type = child_node.OpType();

   bool is_valid_child_type = false;
@@ -44,7 +47,9 @@
   }

   const auto child_node_unit_it = node_unit_map.find(&child_node);
-  assert(child_node_unit_it != node_unit_map.end());
+  if (child_node_unit_it == node_unit_map.end()) {
+    return nullptr;
+  }
   const NodeUnit* child_node_unit = child_node_unit_it->second;

   // Check if child node has already been handled. Should not be the case if the calling
@@ -290,64 +295,84 @@ static std::vector<const Node*> FindQDQNodes(const GraphViewer& graph_viewer, co
   return nodes;
 }

-static Status GetConvDQNodeUnits(
-    /*out*/ std::vector<const NodeUnit*>& dq_node_units,
+static std::vector<const NodeUnit*> GetConvDQs(
     const GraphViewer& graph_viewer,
     const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
     const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>& node_unit_to_qnn_node_group,
-    const Node& conv_node,
-    const Node& q_node) {
-  assert((conv_node.OpType() == "Conv" || conv_node.OpType() == "ConvTranspose") &&
-         q_node.OpType() == QDQ::QOpName);
+    const Node& conv_node) {
+  assert(conv_node.OpType() == "Conv" || conv_node.OpType() == "ConvTranspose");
   std::vector<const Node*> dq_nodes = FindQDQNodes(graph_viewer, conv_node, /*find_dq_nodes*/ true);
-  std::vector<const Node*> q_nodes = {&q_node};
   int num_dq_inputs = NumActualValues(conv_node, /*input*/ true);

   // Within a QDQ node group, a target node input is the only consumer of each DQ.
-  ORT_RETURN_IF_NOT(num_dq_inputs == static_cast<int>(dq_nodes.size()),
-                    "Conv should be the only consumer of each DQ");
+  if (num_dq_inputs != static_cast<int>(dq_nodes.size())) {
+    return {};
+  }
+
+  std::vector<const NodeUnit*> dq_node_units;
   for (const auto* dq_node : dq_nodes) {
-    ORT_RETURN_IF(graph_viewer.NodeProducesGraphOutput(*dq_node),
-                  "QDQ ", conv_node.OpType(), "'s input DQ node must not produce a graph output");
+    if (graph_viewer.NodeProducesGraphOutput(*dq_node)) {
+      return {};
+    }

     const bool dq_has_single_output_edge_to_target = dq_node->GetOutputEdgesCount() == 1 &&
                                                      dq_node->OutputEdgesBegin()->GetNode().Index() == conv_node.Index();
-    ORT_RETURN_IF_NOT(dq_has_single_output_edge_to_target, "DQ should have a single output to Conv");
+    if (!dq_has_single_output_edge_to_target) {
+      return {};
+    }
+
+    const auto it = node_to_node_unit.find(dq_node);
+    if (it == node_to_node_unit.end()) {
+      return {};
+    }
+
+    const NodeUnit* dq_node_unit = it->second;
+
+    if (!dq_node_unit || node_unit_to_qnn_node_group.count(dq_node_unit) != 0) {
+      return {};
+    }
+
+    if (dq_node_unit->UnitType() != NodeUnit::Type::SingleNode) {
+      return {};
+    }
+
+    dq_node_units.push_back(dq_node_unit);
   }

-  // input and output types need to be same
-  int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
-  int32_t dt_weight = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
-  int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
-  ORT_RETURN_IF(dt_input != dt_output, "Conv input[0] and output quantization types must match");
+  return dq_node_units;
+}

-  if (dt_input == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) {
-    ORT_RETURN_IF(dt_weight != dt_input,
-                  conv_node.OpType(), "'s input[0] and input[1] quantization types must match if input[0] is int8");
+static bool IsValidQDQConv(gsl::span<const NodeUnit* const> dq_node_units,
+                           gsl::not_null<const NodeUnit*> q_node_unit) {
+  assert(q_node_unit->OpType() == QDQ::QOpName);
+  const size_t num_dqs = dq_node_units.size();
+  if (num_dqs != 2 && num_dqs != 3) {
+    return false;
   }

-  if (dq_nodes.size() == 3) {  // has bias
-    int32_t dt_bias = dq_nodes[2]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
-    ORT_RETURN_IF(dt_bias != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32,
-                  "QDQ ", conv_node.OpType(), " must have int32 quantized bias");
+  // input and output types need to be same
+  int32_t dt_input = dq_node_units[0]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
+  int32_t dt_weight = dq_node_units[1]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
+  int32_t dt_output = q_node_unit->GetNode().OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
+  if (dt_input != dt_output) {
+    return false;
   }

-  dq_node_units.reserve(dq_nodes.size());
-  for (const auto* dq_node : dq_nodes) {
-    const auto it = node_to_node_unit.find(dq_node);
-    assert(it != node_to_node_unit.end());
-    const NodeUnit* dq_node_unit = it->second;
+  if (dt_input == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) {
+    if (dt_weight != dt_input) {
+      return false;
+    }
+  }

-    ORT_RETURN_IF_NOT(node_unit_to_qnn_node_group.count(dq_node_unit) == 0,
-                      "DQ NodeUnit ", dq_node_unit->Name(), " has already been added to another QnnNodeGroup");
-    ORT_RETURN_IF_NOT(dq_node_unit->UnitType() == NodeUnit::Type::SingleNode,
-                      "Expect DQ to be a NodeUnit of type SingleNode");
-    dq_node_units.push_back(dq_node_unit);
+  if (num_dqs == 3) {  // has bias
+    int32_t dt_bias = dq_node_units[2]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
+    if (dt_bias != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) {
+      return false;
+    }
   }

-  return Status::OK();
+  return true;
 }

 Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper,
@@ -378,12 +403,11 @@ Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper,
   return conv_op_builder->AddToModelBuilder(qnn_model_wrapper, custom_node_unit, logger, validate);
 }

-Status TryConvActivationFusion(/*out*/ std::optional<QnnNodeGroup>& qnn_node_group,
-                               QnnModelWrapper& qnn_model_wrapper,
-                               const NodeUnit& conv_node_unit,
-                               const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
-                               const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>& node_unit_to_qnn_node_group,
-                               const logging::Logger& logger) {
+std::optional<QnnNodeGroup> TryConvActivationFusion(QnnModelWrapper& qnn_model_wrapper,
+                                                    const NodeUnit& conv_node_unit,
+                                                    const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
+                                                    const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>& node_unit_to_qnn_node_group,
+                                                    const logging::Logger& logger) {
   // Expect that this function is called with a standalone Conv or ConvTranspose.
   assert((conv_node_unit.OpType() == "Conv" || conv_node_unit.OpType() == "ConvTranspose") &&
          conv_node_unit.UnitType() == NodeUnit::Type::SingleNode);
@@ -395,7 +419,7 @@ Status TryConvActivationFusion(/*out*/ std::optional<QnnNodeGroup>& qnn_node_gro
   const NodeUnit* activation_node_unit = GetOnlyChildOfType(graph_viewer, conv_node_unit, activation_op_types,
                                                             node_to_node_unit, node_unit_to_qnn_node_group);

   if (activation_node_unit == nullptr) {
-    return Status::OK();
+    return std::nullopt;
   }

   // Relu/Clip must have a single Q child.
@@ -404,28 +428,25 @@ Status TryConvActivationFusion(/*out*/ std::optional<QnnNodeGroup>& qnn_node_gro
                                                    node_to_node_unit, node_unit_to_qnn_node_group);

   if (q_node_unit == nullptr) {
-    return Status::OK();
+    return std::nullopt;
   }

   // Check if Clip/Relu can be removed because the Q node provides an equivalent effect.
   if (!CanActivationBeRemoved(qnn_model_wrapper, *activation_node_unit, *q_node_unit)) {
-    return Status::OK();
+    return std::nullopt;
   }

   // Create a QDQ node group with DQ* -> Conv -> Q
   const Node& conv_node = conv_node_unit.GetNode();
   const Node& activation_node = activation_node_unit->GetNode();
-  const Node& q_node = q_node_unit->GetNode();

-  std::vector<const NodeUnit*> dq_node_units;
-  QNN_RETURN_OK_IF_ERROR(GetConvDQNodeUnits(dq_node_units,
-                                            graph_viewer,
-                                            node_to_node_unit,
-                                            node_unit_to_qnn_node_group,
-                                            conv_node,
-                                            q_node),
-                         logger);
+  std::vector<const NodeUnit*> dq_node_units = GetConvDQs(graph_viewer,
+                                                          node_to_node_unit,
+                                                          node_unit_to_qnn_node_group,
+                                                          conv_node);

-  assert(dq_node_units.size() == 3 || dq_node_units.size() == 2);
+  if (!IsValidQDQConv(dq_node_units, q_node_unit)) {
+    return std::nullopt;
+  }

   // Create a temporary QnnModelWrapper for validation only. We need to be sure that this fusion will work before
   // modifying the actual QnnModelWrapper. This allows us to revert to the traditional OpBuilder workflow if this
@@ -439,14 +460,16 @@ Status TryConvActivationFusion(/*out*/ std::optional<QnnNodeGroup>& qnn_node_gro
                                      qnn_model_wrapper.GetInitializerLookup(),
                                      qnn_model_wrapper.GetQnnBackendType());

-  QNN_RETURN_OK_IF_ERROR(QnnConvActivationFusionAdd(tmp_model_wrapper,
-                                                    dq_node_units,
-                                                    &conv_node_unit,
-                                                    activation_node_unit,
-                                                    q_node_unit,
-                                                    logger,
-                                                    /*validate*/ true),
-                         logger);
+  if (Status status = QnnConvActivationFusionAdd(tmp_model_wrapper,
+                                                 dq_node_units,
+                                                 &conv_node_unit,
+                                                 activation_node_unit,
+                                                 q_node_unit,
+                                                 logger,
+                                                 /*validate*/ true);
+      !status.IsOK()) {
+    return std::nullopt;
+  }

   // Validation passed, so create a QnnNodeGroup.
   // If we encounter an error, we return it directly to caller.
@@ -455,14 +478,14 @@ Status TryConvActivationFusion(/*out*/ std::optional<QnnNodeGroup>& qnn_node_gro
                         << "] activation_node name: [" << activation_node.Name()
                         << "]";

-  qnn_node_group = QnnNodeGroup{};
+  std::optional<QnnNodeGroup> qnn_node_group = QnnNodeGroup{};
   qnn_node_group->type_ = QnnNodeGroup::Type::ConvActivationFusion;
   qnn_node_group->node_units_ = std::move(dq_node_units);
   qnn_node_group->node_units_.push_back(&conv_node_unit);
   qnn_node_group->node_units_.push_back(activation_node_unit);
   qnn_node_group->node_units_.push_back(q_node_unit);

-  return Status::OK();
+  return qnn_node_group;
 }
 }  // namespace qnn
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h
index 9cca16536ad95..3f5bdfc3078dc 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h
@@ -23,11 +23,11 @@ Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper,
                                   const logging::Logger& logger,
                                   bool validate = false);

-Status TryConvActivationFusion(/*out*/ std::optional<QnnNodeGroup>& qnn_node_group,
-                               QnnModelWrapper& qnn_model_wrapper,
-                               const NodeUnit& conv_node_unit,
-                               const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
-                               const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>& node_unit_to_qnn_node_group,
-                               const logging::Logger& logger);
+std::optional<QnnNodeGroup> TryConvActivationFusion(
+    QnnModelWrapper& qnn_model_wrapper,
+    const NodeUnit& conv_node_unit,
+    const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
+    const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>& node_unit_to_qnn_node_group,
+    const logging::Logger& logger);

 }  // namespace qnn
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc b/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc
index 358292cb3cc17..04b727f534424 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc
@@ -308,12 +308,12 @@ const NodeUnit* QnnNodeGroup::GetTargetNodeUnit(const logging::Logger& logger) c
  * \param do_op_validation True if should call QNN operator validation APIs.
  * \return An onnxruntime::Status
  */
-static Status TryDQQFusion(std::optional<QnnNodeGroup>& qnn_node_group,
-                           QnnModelWrapper& qnn_model_wrapper,
-                           const NodeUnit& dq_node_unit,
-                           const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
-                           const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>& node_unit_to_qnn_node_group,
-                           const logging::Logger& logger) {
+static std::optional<QnnNodeGroup> TryDQQFusion(
+    QnnModelWrapper& qnn_model_wrapper,
+    const NodeUnit& dq_node_unit,
+    const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
+    const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>& node_unit_to_qnn_node_group,
+    const logging::Logger& logger) {
   // Expect that this function is called with a standalone DQ.
   assert(dq_node_unit.OpType() == QDQ::DQOpName && dq_node_unit.UnitType() == NodeUnit::Type::SingleNode);
@@ -322,27 +322,33 @@ static Status TryDQQFusion(std::optional<QnnNodeGroup>& qnn_node_group,

   // DQ must have a single child (1 output edge) and must not produce a graph output.
   if (dq_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(dq_node)) {
-    return Status::OK();
+    return std::nullopt;
   }

   const Node& q_node = dq_node.OutputEdgesBegin()->GetNode();
   if (q_node.OpType() != QDQ::QOpName) {
-    return Status::OK();
+    return std::nullopt;
+  }
+
+  if (graph_viewer.GetNode(q_node.Index()) == nullptr) {
+    return std::nullopt;  // Node is not in this GraphViewer
   }

   const auto q_node_unit_it = node_to_node_unit.find(&q_node);
-  ORT_RETURN_IF(q_node_unit_it == node_to_node_unit.end(), "Node does not have a corresponding NodeUnit");
+  if (q_node_unit_it == node_to_node_unit.end()) {
+    return std::nullopt;
+  }
   const NodeUnit* q_node_unit = q_node_unit_it->second;

   // child must not already be part of a QDQ NodeUnit (i.e., be standalone).
   if (q_node_unit->UnitType() != NodeUnit::Type::SingleNode) {
-    return Status::OK();
+    return std::nullopt;
   }

   // Check if child node has already been handled. Should not be the case if this
   // fusion function has been called in topological order, but check to be safe.
   if (node_unit_to_qnn_node_group.count(q_node_unit) != 0) {
-    return Status::OK();
+    return std::nullopt;
   }

   auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) {
@@ -351,11 +357,14 @@ static Status TryDQQFusion(std::optional<QnnNodeGroup>& qnn_node_group,

   // DQ and Q must have equal scale type and different zp type.
   if (!QDQ::IsDQQConversion(dq_node, q_node, get_const_initializer, graph_viewer.ModelPath())) {
-    return Status::OK();
+    return std::nullopt;
   }

-  QNN_RETURN_OK_IF_ERROR(QnnDQQFusionAdd(qnn_model_wrapper, dq_node_unit, *q_node_unit, logger, /*validate*/ true),
-                         logger);
+  if (Status status = QnnDQQFusionAdd(qnn_model_wrapper, dq_node_unit, *q_node_unit,
+                                      logger, /*validate*/ true);
+      !status.IsOK()) {
+    return std::nullopt;
+  }

   // Validation passed, so create a QnnNodeGroup.
   LOGS(logger, VERBOSE) << " Will use QNN Convert via fusion. dq_node name: [" << dq_node.Name()
                         << "] dq_node optype: [" << dq_node.OpType()
                         << "] q_node name: [" << q_node_unit->Name()
                         << "] q_node optype: [" << q_node_unit->OpType()
                         << "]";

-  qnn_node_group = QnnNodeGroup{};
+  std::optional<QnnNodeGroup> qnn_node_group = QnnNodeGroup{};
   qnn_node_group->type_ = QnnNodeGroup::Type::DQQFusion;
   qnn_node_group->node_units_.push_back(&dq_node_unit);
   qnn_node_group->node_units_.push_back(q_node_unit);

-  return Status::OK();
+  return qnn_node_group;
 }

 /**
@@ -386,16 +395,16 @@ static Status TryDQQFusion(std::optional<QnnNodeGroup>& qnn_node_group,
  * \param do_op_validation True if should call QNN operator validation APIs.
  * \return A Status indicating a potential failure.
  */
-static Status TryHardSigmoidMulFusion(std::optional<QnnNodeGroup>& qnn_node_group,
-                                      QnnModelWrapper& qnn_model_wrapper,
-                                      const NodeUnit& hardsigmoid_node_unit,
-                                      const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
-                                      const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>& node_unit_to_qnn_node_group,
-                                      const logging::Logger& logger) {
+static std::optional<QnnNodeGroup> TryHardSigmoidMulFusion(
+    QnnModelWrapper& qnn_model_wrapper,
+    const NodeUnit& hardsigmoid_node_unit,
+    const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
+    const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>& node_unit_to_qnn_node_group,
+    const logging::Logger& logger) {
   // Looking for a standalone HardSigmoid to start the sequence.
   if (hardsigmoid_node_unit.OpType() != "HardSigmoid" ||
       hardsigmoid_node_unit.UnitType() != NodeUnit::Type::SingleNode) {
-    return Status::OK();
+    return std::nullopt;
   }

   NodeAttrHelper hs_attr_helper(hardsigmoid_node_unit);
@@ -408,7 +417,7 @@ static Status TryHardSigmoidMulFusion(std::optional<QnnNodeGroup>& qnn_node_grou

   // Check for explicit values of alpha and beta.
   if (std::abs(alpha - req_alpha) > alpha_eps || std::abs(beta - req_beta) > beta_eps) {
-    return Status::OK();
+    return std::nullopt;
   }

   const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer();
@@ -416,27 +425,33 @@ static Status TryHardSigmoidMulFusion(std::optional<QnnNodeGroup>& qnn_node_grou

   // HardSigmoid must have a single child (1 output edge) and must not produce a graph output.
   if (hs_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(hs_node)) {
-    return Status::OK();
+    return std::nullopt;
   }

   const Node& mul_node = hs_node.OutputEdgesBegin()->GetNode();
   if (mul_node.OpType() != "Mul") {
-    return Status::OK();
+    return std::nullopt;
+  }
+
+  if (graph_viewer.GetNode(mul_node.Index()) == nullptr) {
+    return std::nullopt;  // Node is not in this GraphViewer
   }

   const auto mul_node_unit_it = node_to_node_unit.find(&mul_node);
-  ORT_RETURN_IF(mul_node_unit_it == node_to_node_unit.end(), "Mul Node does not have a corresponding NodeUnit");
+  if (mul_node_unit_it == node_to_node_unit.end()) {
+    return std::nullopt;
+  }
   const NodeUnit* mul_node_unit = mul_node_unit_it->second;

   // Check if Mul node has already been handled. Should not be the case if this
   // fusion function has been called in topological order, but check to be safe.
   if (node_unit_to_qnn_node_group.count(mul_node_unit) != 0) {
-    return Status::OK();
+    return std::nullopt;
   }

   // Mul child must not already be part of a QDQ NodeUnit (i.e., be standalone).
   if (mul_node_unit->UnitType() != NodeUnit::Type::SingleNode) {
-    return Status::OK();
+    return std::nullopt;
   }

   // Input to HardSigmoid must also be the other input to the Mul.
@@ -445,38 +460,40 @@ static Status TryHardSigmoidMulFusion(std::optional<QnnNodeGroup>& qnn_node_grou
                          mul_node.InputDefs()[1]->Name() == hs_input_name;

   if (!same_root_input) {
-    return Status::OK();
+    return std::nullopt;
   }

-  QNN_RETURN_OK_IF_ERROR(QnnHardSigmoidMulFusionAdd(qnn_model_wrapper, hardsigmoid_node_unit, *mul_node_unit,
-                                                    logger, /*validate*/ true),
-                         logger);
+  if (Status status = QnnHardSigmoidMulFusionAdd(qnn_model_wrapper, hardsigmoid_node_unit, *mul_node_unit,
+                                                 logger, /*validate*/ true);
+      !status.IsOK()) {
+    return std::nullopt;
+  }

   // Validation passed, so create a QnnNodeGroup. Any errors are now passed back to the caller.
   LOGS(logger, VERBOSE) << "Will use QNN HardSwish via fusion. HardSigmoid name: [" << hardsigmoid_node_unit.Name()
                         << "] Mul name: [" << mul_node_unit->Name()
                         << "]";

-  qnn_node_group = QnnNodeGroup{};
+  std::optional<QnnNodeGroup> qnn_node_group = QnnNodeGroup{};
   qnn_node_group->type_ = QnnNodeGroup::Type::HardSigmoidMulFusion;
   qnn_node_group->node_units_.push_back(&hardsigmoid_node_unit);
   qnn_node_group->node_units_.push_back(mul_node_unit);

-  return Status::OK();
+  return qnn_node_group;
 }

-using FusionFunc = Status (*)(std::optional<QnnNodeGroup>&,
-                              QnnModelWrapper&,
-                              const NodeUnit&,
-                              const std::unordered_map<const Node*, const NodeUnit*>&,
-                              const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>&,
-                              const logging::Logger&);
-
-static Status TryQnnFusions(/*out*/ std::optional<QnnNodeGroup>& fused_node_group,
-                            QnnModelWrapper& qnn_model_wrapper,
-                            const NodeUnit& starting_node_unit,
-                            const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
-                            const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>& node_unit_to_qnn_node_group,
-                            const logging::Logger& logger) {
+using FusionFunc = std::optional<QnnNodeGroup> (*)(
+    QnnModelWrapper&,
+    const NodeUnit&,
+    const std::unordered_map<const Node*, const NodeUnit*>&,
+    const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>&,
+    const logging::Logger&);
+
+static std::optional<QnnNodeGroup> TryQnnFusions(
+    QnnModelWrapper& qnn_model_wrapper,
+    const NodeUnit& starting_node_unit,
+    const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
+    const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>& node_unit_to_qnn_node_group,
+    const logging::Logger& logger) {
   // Maps a starting operator type to the fusion function.
   static std::unordered_map<std::string, FusionFunc> fusions = {
       {"DequantizeLinear", TryDQQFusion},
@@ -487,16 +504,16 @@

   // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes).
   if (starting_node_unit.UnitType() != NodeUnit::Type::SingleNode) {
-    return Status::OK();
+    return std::nullopt;
   }

   auto iter = fusions.find(starting_node_unit.OpType());
   if (iter != fusions.end()) {
     FusionFunc fusion_func = iter->second;
-    ORT_RETURN_IF_ERROR(fusion_func(fused_node_group, qnn_model_wrapper, starting_node_unit, node_to_node_unit,
-                                    node_unit_to_qnn_node_group, logger));
+    return fusion_func(qnn_model_wrapper, starting_node_unit, node_to_node_unit,
+                       node_unit_to_qnn_node_group, logger);
   }
-  return Status::OK();
+  return std::nullopt;
 }

 Status GetQnnNodeGroups(/*out*/ std::vector<QnnNodeGroup>& qnn_node_groups,
@@ -538,9 +555,9 @@ Status GetQnnNodeGroups(/*out*/ std::vector<QnnNodeGroup>& qnn_node_groups,
       continue;  // Already handled this node unit
     }

-    std::optional<QnnNodeGroup> fused_node_group;
-    ORT_RETURN_IF_ERROR(TryQnnFusions(fused_node_group, qnn_model_wrapper, *node_unit,
-                                      node_to_node_unit, node_unit_to_qnn_node_group, logger));
+    std::optional<QnnNodeGroup> fused_node_group = TryQnnFusions(qnn_model_wrapper, *node_unit,
+                                                                 node_to_node_unit, node_unit_to_qnn_node_group,
+                                                                 logger);

     if (fused_node_group.has_value()) {
       const QnnNodeGroup::IndexType index = tmp_qnn_node_groups.size();
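
Note on the pattern: the heart of this patch is an API shape swap, from `Status f(/*out*/ std::optional<QnnNodeGroup>&, ...)` to `std::optional<QnnNodeGroup> f(...)`, so "this fusion does not apply" is expressed as std::nullopt instead of Status::OK() plus an untouched out-parameter. A minimal self-contained sketch of the before/after calling convention (generic names; not code from this patch):

    #include <iostream>
    #include <optional>
    #include <string>

    // Before: "no match" and "success" both returned Status::OK(), and the
    // caller had to inspect the out-parameter to tell them apart:
    //   Status TryFuse(std::optional<Group>& out, const std::string& op_type);
    //
    // After: the optional return value is the single source of truth.
    struct Group {
      std::string name;
    };

    static std::optional<Group> TryFuse(const std::string& op_type) {
      if (op_type != "DequantizeLinear") {
        return std::nullopt;  // pattern not applicable -- not an error
      }
      return Group{"DQ -> Q convert group"};
    }

    int main() {
      for (const std::string op : {"DequantizeLinear", "Mul"}) {
        // C++17 if-init: bind the optional and branch on it in one statement,
        // mirroring how GetQnnNodeGroups() consumes TryQnnFusions() above.
        if (std::optional<Group> g = TryFuse(op); g.has_value()) {
          std::cout << op << " fused into: " << g->name << '\n';
        } else {
          std::cout << op << ": no fusion, use the regular OpBuilder path\n";
        }
      }
      return 0;
    }

As the diff shows, even a real validation failure inside QnnConvActivationFusionAdd() is mapped to std::nullopt by the Try* helpers, so the caller falls back to the traditional OpBuilder workflow rather than aborting graph partitioning.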