Remove use of optimizer qdq utils from fusion code; Rename fusion classes
adrianlizarraga committed Jul 28, 2024
1 parent 576c2f8 commit 0e7eece
Showing 12 changed files with 404 additions and 395 deletions.
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc
@@ -239,6 +239,8 @@ bool QnnModelWrapper::CreateQnnNode(const std::string& qnn_node_name,
std::string error_msg;
bool rt = op_config_wrapper.QnnGraphOpValidation(qnn_interface_, backend_handle_, error_msg);
if (!rt) {
// TODO(adrianlizarraga): Return a Status with the error message so that aggregated logs show a more
// specific validation error (instead of "failed to add node").
LOGS(logger_, WARNING) << error_msg;
}
return rt;
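A minimal sketch of the variant the TODO above suggests, assuming a hypothetical helper (ValidateQnnNode and the exact wrapper type name are illustrative, not part of this commit): return the QNN validation message in a Status instead of only logging a warning.

// Hypothetical sketch: propagate the specific validation error so that
// aggregated logs show more than a generic "failed to add node".
Status QnnModelWrapper::ValidateQnnNode(QnnOpConfigWrapper& op_config_wrapper) {
  std::string error_msg;
  if (!op_config_wrapper.QnnGraphOpValidation(qnn_interface_, backend_handle_, error_msg)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN op validation failed: ", error_msg);
  }
  return Status::OK();
}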
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/qnn/builder/qnn_node_group.h
@@ -8,12 +8,14 @@
#include <unordered_map>
#include <vector>

#include "core/common/logging/logging.h"
#include "core/framework/node_unit.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"

namespace onnxruntime {
namespace qnn {

class QnnModelWrapper;

class IQnnNodeGroup {
public:
virtual ~IQnnNodeGroup() = default;

Large diffs are not rendered by default.

@@ -10,41 +10,39 @@
#include <vector>

#include "core/framework/node_unit.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/qnn_node_group.h"

namespace onnxruntime {
namespace qnn {

std::unique_ptr<IQnnNodeGroup> TryConvActivationFusion(
QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& conv_node_unit,
const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
const logging::Logger& logger);
class QnnModelWrapper;

namespace conv_act_fusion {

class QnnNodeGroup : public IQnnNodeGroup {
class ConvActivationFusion : public IQnnNodeGroup {
public:
QnnNodeGroup(const NodeUnit& dq_node_unit_0,
const NodeUnit& dq_node_unit_1,
const NodeUnit* dq_node_unit_2,
const NodeUnit& conv_node_unit,
const NodeUnit& activation_node_unit,
const NodeUnit& q_node_unit);
ORT_DISALLOW_COPY_AND_ASSIGNMENT(QnnNodeGroup);
ConvActivationFusion(const NodeUnit& dq_node_unit_0,
const NodeUnit& dq_node_unit_1,
const NodeUnit* dq_node_unit_2,
const NodeUnit& conv_node_unit,
const NodeUnit& activation_node_unit,
const NodeUnit& q_node_unit);
ORT_DISALLOW_COPY_AND_ASSIGNMENT(ConvActivationFusion);

Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override;
Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override;
gsl::span<const NodeUnit* const> GetNodeUnits() const override;
const NodeUnit* GetTargetNodeUnit() const override;
std::string_view Type() const override { return "ConvActivationFusion"; }

static std::unique_ptr<IQnnNodeGroup> TryFusion(
QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& conv_node_unit,
const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
const logging::Logger& logger);

private:
std::array<const NodeUnit*, 6> node_units_; // Last elem is nullptr if bias DQ is missing.
};

} // namespace conv_act_fusion
} // namespace qnn
} // namespace onnxruntime
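A plausible sketch (an assumption, not necessarily this commit's implementation) of how the trailing-nullptr convention noted on node_units_ above can be handled when exposing the node units:

gsl::span<const NodeUnit* const> ConvActivationFusion::GetNodeUnits() const {
  // node_units_ stores the fused units with the optional bias DQ in the last slot;
  // that element is nullptr when the bias DQ is absent, so trim it from the span.
  const size_t count = (node_units_.back() != nullptr) ? node_units_.size()
                                                       : node_units_.size() - 1;
  return gsl::make_span(node_units_.data(), count);
}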
175 changes: 114 additions & 61 deletions onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc
@@ -6,23 +6,89 @@
#include <limits>
#include <optional>
#include "core/graph/graph_utils.h"
#include "core/optimizer/qdq_transformer/qdq_util.h"
#include "core/framework/node_unit.h"
#include "core/providers/shared/utils/utils.h"
#include "core/providers/qnn/builder/qnn_utils.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/qnn/builder/qnn_node_group/utils.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"

namespace onnxruntime {
namespace qnn {

static Status QnnDQQFusionAdd(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& dq_node_unit,
const NodeUnit& q_node_unit,
const logging::Logger& logger,
bool validate = false) {
// Forward declarations.
#define ValidateOnQnn(qnn_model_wrapper, dq_node_unit, q_node_unit) \
CreateOrValidateOnQnn((qnn_model_wrapper), (dq_node_unit), (q_node_unit), true)
#define CreateOnQnn(qnn_model_wrapper, dq_node_unit, q_node_unit) \
CreateOrValidateOnQnn((qnn_model_wrapper), (dq_node_unit), (q_node_unit), false)
static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& dq_node_unit,
const NodeUnit& q_node_unit, bool validate);
static bool IsDQQConversion(const GraphViewer& graph_viewer, const Node& dq_node, const Node& q_node);
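// Note on the macro pattern above: ValidateOnQnn(w, dq, q) expands to
// CreateOrValidateOnQnn((w), (dq), (q), true) and CreateOnQnn(w, dq, q) to the
// validate == false variant, so QNN op validation and op creation share a
// single implementation defined later in this file.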

std::unique_ptr<IQnnNodeGroup> DQQFusion::TryFusion(
QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& dq_node_unit,
const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
const logging::Logger& logger) {
ORT_UNUSED_PARAMETER(logger);
assert(dq_node_unit.OpType() == QDQ::DQOpName && q_node_unit.OpType() == QDQ::QOpName);
// Expect that this function is called with a standalone DQ.
if (dq_node_unit.OpType() != DEQUANTIZE_LINEAR || dq_node_unit.UnitType() != NodeUnit::Type::SingleNode) {
return nullptr;
}

const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer();
const Node& dq_node = dq_node_unit.GetNode();

// DQ must have a single Q child (1 output edge) and must not produce a graph output.
const std::array<std::string_view, 1> child_types = {QUANTIZE_LINEAR};
const NodeUnit* q_node_unit = GetOnlyChildOfType(graph_viewer, dq_node_unit, child_types,
node_to_node_unit, node_unit_to_qnn_node_group);

if (q_node_unit == nullptr) {
return nullptr;
}

// DQ and Q must have equal scale type and different zp type.
if (!IsDQQConversion(graph_viewer, dq_node, q_node_unit->GetNode())) {
return nullptr;
}

if (Status status = ValidateOnQnn(qnn_model_wrapper, dq_node_unit, *q_node_unit);
!status.IsOK()) {
return nullptr;
}

return std::make_unique<DQQFusion>(dq_node_unit, *q_node_unit);
}

DQQFusion::DQQFusion(const NodeUnit& dq_node_unit, const NodeUnit& q_node_unit)
: node_units_{&dq_node_unit, &q_node_unit} {
}

Status DQQFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const {
ORT_UNUSED_PARAMETER(logger);
return ValidateOnQnn(qmw, *node_units_[0], *node_units_[1]);
}

Status DQQFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const {
ORT_UNUSED_PARAMETER(logger);
return CreateOnQnn(qmw, *node_units_[0], *node_units_[1]);
}

gsl::span<const NodeUnit* const> DQQFusion::GetNodeUnits() const {
return node_units_;
}

const NodeUnit* DQQFusion::GetTargetNodeUnit() const {
return node_units_[0];
}

static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& dq_node_unit,
const NodeUnit& q_node_unit,
bool validate) {
assert(dq_node_unit.OpType() == DEQUANTIZE_LINEAR && q_node_unit.OpType() == QUANTIZE_LINEAR);
const auto& node_name = utils::GetNodeName(dq_node_unit);
const NodeUnitIODef& input_def = dq_node_unit.Inputs()[0];
const NodeUnitIODef& output_def = q_node_unit.Outputs()[0];
Expand Down Expand Up @@ -56,70 +122,57 @@ static Status QnnDQQFusionAdd(QnnModelWrapper& qnn_model_wrapper,
return Status::OK();
}

std::unique_ptr<IQnnNodeGroup> TryDQQFusion(
QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& dq_node_unit,
const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
const logging::Logger& logger) {
// Expect that this function is called with a standalone DQ.
if (dq_node_unit.OpType() != "DequantizeLinear" || dq_node_unit.UnitType() != NodeUnit::Type::SingleNode) {
return nullptr;
}

const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer();
const Node& dq_node = dq_node_unit.GetNode();

// DQ must have a single Q child (1 output edge) and must not produce a graph output.
const std::array<std::string_view, 1> child_types = {"QuantizeLinear"};
const NodeUnit* q_node_unit = GetOnlyChildOfType(graph_viewer, dq_node_unit, child_types,
node_to_node_unit, node_unit_to_qnn_node_group);
static bool IsDQQConversion(const GraphViewer& graph_viewer, const Node& dq_node, const Node& q_node) {
ConstPointerContainer<std::vector<NodeArg*>> dq_input_defs = dq_node.InputDefs();
ConstPointerContainer<std::vector<NodeArg*>> q_input_defs = q_node.InputDefs();

if (q_node_unit == nullptr) {
return nullptr;
}
auto is_scalar_shape = [](const NodeArg& input_arg) -> bool {
auto shape = input_arg.Shape();
if (shape == nullptr) {
return false;
}

auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) {
return graph_viewer.GetConstantInitializer(initializer_name, true);
auto dim_size = shape->dim_size();
return dim_size == 0 || (dim_size == 1 && shape->dim(0).has_dim_value() && shape->dim(0).dim_value() == 1);
};

// DQ and Q must have equal scale type and different zp type.
if (!QDQ::IsDQQConversion(dq_node, q_node_unit->GetNode(), get_const_initializer, graph_viewer.ModelPath())) {
return nullptr;
// Q/DQ with optional inputs is not supported.
// Non-scalar Q/DQ scale and zero point inputs are not supported.
if (dq_input_defs.size() != QDQ_MAX_NUM_INPUTS ||
q_input_defs.size() != QDQ_MAX_NUM_INPUTS ||
!is_scalar_shape(*q_input_defs[QDQ_SCALE_INPUT_IDX]) ||
!is_scalar_shape(*q_input_defs[QDQ_ZERO_POINT_INPUT_IDX]) ||
!is_scalar_shape(*dq_input_defs[QDQ_SCALE_INPUT_IDX]) ||
!is_scalar_shape(*dq_input_defs[QDQ_ZERO_POINT_INPUT_IDX])) {
return false;
}

if (Status status = QnnDQQFusionAdd(qnn_model_wrapper, dq_node_unit, *q_node_unit,
logger, /*validate*/ true);
!status.IsOK()) {
return nullptr;
// If the Q/DQ scale and zero point are not constant initializers, return false.
const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto =
graph_viewer.GetConstantInitializer(dq_input_defs[QDQ_SCALE_INPUT_IDX]->Name());
const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto =
graph_viewer.GetConstantInitializer(q_input_defs[QDQ_SCALE_INPUT_IDX]->Name());
const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto =
graph_viewer.GetConstantInitializer(dq_input_defs[QDQ_ZERO_POINT_INPUT_IDX]->Name());
const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto =
graph_viewer.GetConstantInitializer(q_input_defs[QDQ_ZERO_POINT_INPUT_IDX]->Name());
if (nullptr == q_zp_tensor_proto ||
nullptr == dq_zp_tensor_proto ||
nullptr == q_scale_tensor_proto ||
nullptr == dq_scale_tensor_proto) {
return false;
}

std::unique_ptr<IQnnNodeGroup> qnn_node_group = std::make_unique<dq_q_fusion::QnnNodeGroup>(dq_node_unit,
*q_node_unit);
return qnn_node_group;
}

namespace dq_q_fusion {
QnnNodeGroup::QnnNodeGroup(const NodeUnit& dq_node_unit, const NodeUnit& q_node_unit)
: node_units_{&dq_node_unit, &q_node_unit} {
}

Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const {
return QnnDQQFusionAdd(qmw, *node_units_[0], *node_units_[1], logger, /*validate*/ true);
}

Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const {
return QnnDQQFusionAdd(qmw, *node_units_[0], *node_units_[1], logger, /*validate*/ false);
}

gsl::span<const NodeUnit* const> QnnNodeGroup::GetNodeUnits() const {
return node_units_;
}
// All TensorProtos must have a data type
if (!q_zp_tensor_proto->has_data_type() || !dq_zp_tensor_proto->has_data_type() ||
!q_scale_tensor_proto->has_data_type() || !dq_scale_tensor_proto->has_data_type()) {
return false;
}

const NodeUnit* QnnNodeGroup::GetTargetNodeUnit() const {
return node_units_[0];
// Check that the DQ and Q have the same scale type and different zero point types.
return (dq_zp_tensor_proto->data_type() != q_zp_tensor_proto->data_type()) &&
(dq_scale_tensor_proto->data_type() == q_scale_tensor_proto->data_type());
}
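// Illustrative example (an assumption, not taken from this commit): a sequence
//   DequantizeLinear(x: uint8, scale: float, zero_point: uint8) ->
//   QuantizeLinear(scale: float, zero_point: uint16)
// passes this check: both scales are float (same scale type) while the zero
// points are uint8 vs. uint16 (different types), i.e., an 8-bit-to-16-bit
// conversion that can map onto a single QNN Convert op.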

} // namespace dq_q_fusion
} // namespace qnn
} // namespace onnxruntime
50 changes: 24 additions & 26 deletions onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h
@@ -9,48 +9,46 @@

#include "core/common/common.h"
#include "core/framework/node_unit.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/qnn_node_group.h"

namespace onnxruntime {
namespace qnn {

/**
* Tries to merge a DQ -> Q sequence into a QNN Convert operator. The DQ -> Q must be converting from
* one quantization type (e.g., uint8_t) to another (e.g., uint16_t).
*
* \param fused_nodes Output list of node units that were fused. Remains empty if fusion is not applied.
* \param qnn_model_wrapper The QNN model that is being built.
* \param dq_node_unit The DQ node unit.
* \param q_node_unit The Q node unit.
* \param logger The logger.
* \param do_op_validation True if should call QNN operator validation APIs.
* \return An onnxruntime::Status
*/
std::unique_ptr<IQnnNodeGroup> TryDQQFusion(
QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& dq_node_unit,
const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
const logging::Logger& logger);

namespace dq_q_fusion {

class QnnNodeGroup : public IQnnNodeGroup {
class QnnModelWrapper;

class DQQFusion : public IQnnNodeGroup {
public:
QnnNodeGroup(const NodeUnit& dq_node_unit, const NodeUnit& q_node_unit);
ORT_DISALLOW_COPY_AND_ASSIGNMENT(QnnNodeGroup);
DQQFusion(const NodeUnit& dq_node_unit, const NodeUnit& q_node_unit);
ORT_DISALLOW_COPY_AND_ASSIGNMENT(DQQFusion);

Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override;
Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override;
gsl::span<const NodeUnit* const> GetNodeUnits() const override;
const NodeUnit* GetTargetNodeUnit() const override;
std::string_view Type() const override { return "DQQFusion"; }

/**
* Traverses the graph to check whether the given DQ node unit begins a DQ -> Q sequence that can
* be fused into a QNN Convert operator. The DQ -> Q must convert from one quantization type
* (e.g., uint8_t) to another (e.g., uint16_t).
*
* \param qnn_model_wrapper The QNN model that is being built.
* \param dq_node_unit The DQ node unit that potentially begins the sequence.
* \param node_to_node_unit Maps a Node to its parent NodeUnit.
* \param node_unit_to_qnn_node_group Maps a NodeUnit to the IQnnNodeGroup to which it belongs.
* \param logger The logger.
* \return A valid IQnnNodeGroup on success, or nullptr if the fusion was not applied.
*/
static std::unique_ptr<IQnnNodeGroup> TryFusion(
QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& dq_node_unit,
const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
const logging::Logger& logger);

private:
std::array<const NodeUnit*, 2> node_units_;
};

} // namespace dq_q_fusion
} // namespace qnn
} // namespace onnxruntime
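For context, a hypothetical caller-side sketch (not from this commit; the surrounding grouping pass and map names are assumed) of how the static TryFusion entry point composes with the node-group maps:

// Try to fuse a standalone DQ into a DQ -> Q conversion group.
std::unique_ptr<IQnnNodeGroup> group = DQQFusion::TryFusion(
    qnn_model_wrapper, dq_node_unit, node_to_node_unit, node_unit_to_qnn_node_group, logger);
if (group != nullptr) {
  // Mark the constituent node units as claimed so later passes skip them.
  for (const NodeUnit* unit : group->GetNodeUnits()) {
    node_unit_to_qnn_node_group.emplace(unit, group.get());
  }
}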