Skip to content

Commit

Permalink
Dont always return Status
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianlizarraga committed Jul 27, 2024
1 parent e2c9c00 commit 6f042c8
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 126 deletions.
153 changes: 88 additions & 65 deletions onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ static const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer,

// Child must be of a valid type.
const Node& child_node = parent_node.OutputEdgesBegin()->GetNode();
if (graph_viewer.GetNode(child_node.Index()) == nullptr) {
return nullptr; // Node is not in this GraphViewer
}
const std::string& child_type = child_node.OpType();
bool is_valid_child_type = false;

Expand All @@ -44,7 +47,9 @@ static const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer,
}

const auto child_node_unit_it = node_unit_map.find(&child_node);
assert(child_node_unit_it != node_unit_map.end());
if (child_node_unit_it == node_unit_map.end()) {
return nullptr;
}
const NodeUnit* child_node_unit = child_node_unit_it->second;

// Check if child node has already been handled. Should not be the case if the calling
Expand Down Expand Up @@ -290,64 +295,84 @@ static std::vector<const Node*> FindQDQNodes(const GraphViewer& graph_viewer, co
return nodes;
}

static Status GetConvDQNodeUnits(
/*out*/ std::vector<const NodeUnit*>& dq_node_units,
static std::vector<const NodeUnit*> GetConvDQs(
const GraphViewer& graph_viewer,
const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>& node_unit_to_qnn_node_group,
const Node& conv_node,
const Node& q_node) {
assert((conv_node.OpType() == "Conv" || conv_node.OpType() == "ConvTranspose") &&
q_node.OpType() == QDQ::QOpName);
const Node& conv_node) {
assert(conv_node.OpType() == "Conv" || conv_node.OpType() == "ConvTranspose");
std::vector<const Node*> dq_nodes = FindQDQNodes(graph_viewer, conv_node, /*find_dq_nodes*/ true);
std::vector<const Node*> q_nodes = {&q_node};
int num_dq_inputs = NumActualValues(conv_node, /*input*/ true);

// Within a QDQ node group, a target node input is the only consumer of each DQ.
ORT_RETURN_IF_NOT(num_dq_inputs == static_cast<int>(dq_nodes.size()),
"Conv should be the only consumer of each DQ");
if (num_dq_inputs != static_cast<int>(dq_nodes.size())) {
return {};
}

std::vector<const NodeUnit*> dq_node_units;
for (const auto* dq_node : dq_nodes) {
ORT_RETURN_IF(graph_viewer.NodeProducesGraphOutput(*dq_node),
"QDQ ", conv_node.OpType(), "'s input DQ node must not produce a graph output");
if (graph_viewer.NodeProducesGraphOutput(*dq_node)) {
return {};
}

const bool dq_has_single_output_edge_to_target =
dq_node->GetOutputEdgesCount() == 1 &&
dq_node->OutputEdgesBegin()->GetNode().Index() == conv_node.Index();
ORT_RETURN_IF_NOT(dq_has_single_output_edge_to_target, "DQ should have a single output to Conv");
if (!dq_has_single_output_edge_to_target) {
return {};
}

const auto it = node_to_node_unit.find(dq_node);
if (it == node_to_node_unit.end()) {
return {};
}

const NodeUnit* dq_node_unit = it->second;

if (!dq_node_unit || node_unit_to_qnn_node_group.count(dq_node_unit) != 0) {
return {};
}

if (dq_node_unit->UnitType() != NodeUnit::Type::SingleNode) {
return {};
}

dq_node_units.push_back(dq_node_unit);
}

// input and output types need to be same
int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
int32_t dt_weight = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
ORT_RETURN_IF(dt_input != dt_output, "Conv input[0] and output quantization types must match");
return dq_node_units;
}

if (dt_input == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) {
ORT_RETURN_IF(dt_weight != dt_input,
conv_node.OpType(), "'s input[0] and input[1] quantization types must match if input[0] is int8");
static bool IsValidQDQConv(gsl::span<const NodeUnit*> dq_node_units,
gsl::not_null<const NodeUnit*> q_node_unit) {
assert(q_node_unit->OpType() == QDQ::QOpName);
const size_t num_dqs = dq_node_units.size();
if (num_dqs != 2 && num_dqs != 3) {
return false;
}

if (dq_nodes.size() == 3) { // has bias
int32_t dt_bias = dq_nodes[2]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
ORT_RETURN_IF(dt_bias != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32,
"QDQ ", conv_node.OpType(), " must have int32 quantized bias");
// input and output types need to be same
int32_t dt_input = dq_node_units[0]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
int32_t dt_weight = dq_node_units[1]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
int32_t dt_output = q_node_unit->GetNode().OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
if (dt_input != dt_output) {
return false;
}

dq_node_units.reserve(dq_nodes.size());
for (const auto* dq_node : dq_nodes) {
const auto it = node_to_node_unit.find(dq_node);
assert(it != node_to_node_unit.end());
const NodeUnit* dq_node_unit = it->second;
if (dt_input == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) {
if (dt_weight != dt_input) {
return false;
}
}

ORT_RETURN_IF_NOT(node_unit_to_qnn_node_group.count(dq_node_unit) == 0,
"DQ NodeUnit ", dq_node_unit->Name(), " has already been added to another QnnNodeGroup");
ORT_RETURN_IF_NOT(dq_node_unit->UnitType() == NodeUnit::Type::SingleNode,
"Expect DQ to be a NodeUnit of type SingleNode");
dq_node_units.push_back(dq_node_unit);
if (num_dqs == 3) { // has bias
int32_t dt_bias = dq_node_units[2]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
if (dt_bias != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) {
return false;
}
}

return Status::OK();
return true;
}

Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper,
Expand Down Expand Up @@ -378,12 +403,11 @@ Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper,
return conv_op_builder->AddToModelBuilder(qnn_model_wrapper, custom_node_unit, logger, validate);
}

Status TryConvActivationFusion(/*out*/ std::optional<QnnNodeGroup>& qnn_node_group,
QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& conv_node_unit,
const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>& node_unit_to_qnn_node_group,
const logging::Logger& logger) {
std::optional<QnnNodeGroup> TryConvActivationFusion(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& conv_node_unit,
const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>& node_unit_to_qnn_node_group,
const logging::Logger& logger) {
// Expect that this function is called with a standalone Conv or ConvTranspose.
assert((conv_node_unit.OpType() == "Conv" || conv_node_unit.OpType() == "ConvTranspose") &&
conv_node_unit.UnitType() == NodeUnit::Type::SingleNode);
Expand All @@ -395,7 +419,7 @@ Status TryConvActivationFusion(/*out*/ std::optional<QnnNodeGroup>& qnn_node_gro
const NodeUnit* activation_node_unit = GetOnlyChildOfType(graph_viewer, conv_node_unit, activation_op_types,
node_to_node_unit, node_unit_to_qnn_node_group);
if (activation_node_unit == nullptr) {
return Status::OK();
return std::nullopt;
}

// Relu/Clip must have a single Q child.
Expand All @@ -404,28 +428,25 @@ Status TryConvActivationFusion(/*out*/ std::optional<QnnNodeGroup>& qnn_node_gro
node_to_node_unit, node_unit_to_qnn_node_group);

if (q_node_unit == nullptr) {
return Status::OK();
return std::nullopt;
}

// Check if Clip/Relu can be removed because the Q node provides an equivalent effect.
if (!CanActivationBeRemoved(qnn_model_wrapper, *activation_node_unit, *q_node_unit)) {
return Status::OK();
return std::nullopt;
}

// Create a QDQ node group with DQ* -> Conv -> Q
const Node& conv_node = conv_node_unit.GetNode();
const Node& activation_node = activation_node_unit->GetNode();
const Node& q_node = q_node_unit->GetNode();
std::vector<const NodeUnit*> dq_node_units;
QNN_RETURN_OK_IF_ERROR(GetConvDQNodeUnits(dq_node_units,
graph_viewer,
node_to_node_unit,
node_unit_to_qnn_node_group,
conv_node,
q_node),
logger);
std::vector<const NodeUnit*> dq_node_units = GetConvDQs(graph_viewer,
node_to_node_unit,
node_unit_to_qnn_node_group,
conv_node);

assert(dq_node_units.size() == 3 || dq_node_units.size() == 2);
if (!IsValidQDQConv(dq_node_units, q_node_unit)) {
return std::nullopt;
}

// Create a temporary QnnModelWrapper for validation only. We need to be sure that this fusion will work before
// modifying the actual QnnModelWrapper. This allows us to revert to the traditional OpBuilder workflow if this
Expand All @@ -439,14 +460,16 @@ Status TryConvActivationFusion(/*out*/ std::optional<QnnNodeGroup>& qnn_node_gro
qnn_model_wrapper.GetInitializerLookup(),
qnn_model_wrapper.GetQnnBackendType());

QNN_RETURN_OK_IF_ERROR(QnnConvActivationFusionAdd(tmp_model_wrapper,
dq_node_units,
&conv_node_unit,
activation_node_unit,
q_node_unit,
logger,
/*validate*/ true),
logger);
if (Status status = QnnConvActivationFusionAdd(tmp_model_wrapper,
dq_node_units,
&conv_node_unit,
activation_node_unit,
q_node_unit,
logger,
/*validate*/ true);
!status.IsOK()) {
return std::nullopt;
}

// Validation passed, so create a QnnNodeGroup.
// If we encounter an error, we return it directly to caller.
Expand All @@ -455,14 +478,14 @@ Status TryConvActivationFusion(/*out*/ std::optional<QnnNodeGroup>& qnn_node_gro
<< "] activation_node name: [" << activation_node.Name()
<< "]";

qnn_node_group = QnnNodeGroup{};
std::optional<QnnNodeGroup> qnn_node_group = QnnNodeGroup{};
qnn_node_group->type_ = QnnNodeGroup::Type::ConvActivationFusion;
qnn_node_group->node_units_ = std::move(dq_node_units);
qnn_node_group->node_units_.push_back(&conv_node_unit);
qnn_node_group->node_units_.push_back(activation_node_unit);
qnn_node_group->node_units_.push_back(q_node_unit);

return Status::OK();
return qnn_node_group;
}
} // namespace qnn
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper,
const logging::Logger& logger,
bool validate = false);

Status TryConvActivationFusion(/*out*/ std::optional<QnnNodeGroup>& qnn_node_group,
QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& conv_node_unit,
const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>& node_unit_to_qnn_node_group,
const logging::Logger& logger);
std::optional<QnnNodeGroup> TryConvActivationFusion(
QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& conv_node_unit,
const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>& node_unit_to_qnn_node_group,
const logging::Logger& logger);
} // namespace qnn
} // namespace onnxruntime
Loading

0 comments on commit 6f042c8

Please sign in to comment.