Skip to content

Commit

Permalink
Use different NodeUnit constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianlizarraga committed Jul 27, 2024
1 parent dd8dc3d commit fa79a2f
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 53 deletions.
47 changes: 12 additions & 35 deletions onnxruntime/core/framework/node_unit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -272,41 +272,18 @@ NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_g
}
}

NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group,
const Node& output_activation_node)
: dq_nodes_{GetQDQIONodes(graph_viewer, node_group, true /* is_input */)},
target_node_(*graph_viewer.GetNode(node_group.target_node)),
q_nodes_{GetQDQIONodes(graph_viewer, node_group, false /* is_input */)},
type_(Type::QDQGroup),
inputs_{GetQDQIODefs(target_node_, node_group, true /* is_input */)},
outputs_{GetQDQIODefs(output_activation_node, node_group, false /* is_input */)} {
input_edge_count_ = std::accumulate(dq_nodes_.cbegin(), dq_nodes_.cend(), size_t(0),
[](size_t acc, const Node* node) { return acc + node->GetInputEdgesCount(); });

// add edges for inputs that are not from DQ nodes. there is one edge to each DQ node.
// other inputs could come from initializers or graph inputs (no edges) or other nodes (edge).
input_edge_count_ += target_node_.GetInputEdgesCount() - dq_nodes_.size();

// create output edges. each target node output either goes to Q node/s or non-Q node/s.
// ValidateNodeGroupQDQNodes ensures this.
auto cur_edge = output_activation_node.OutputEdgesBegin();
auto end_edge = output_activation_node.OutputEdgesEnd();
for (; cur_edge != end_edge; ++cur_edge) {
const Node& node = cur_edge->GetNode();

// if node is in q_nodes we hide the Q node.
if (std::find(q_nodes_.cbegin(), q_nodes_.cend(), &node) != q_nodes_.cend()) {
auto src_idx = cur_edge->GetSrcArgIndex();
auto q_cur_edge = node.OutputEdgesBegin();
auto q_end_edge = node.OutputEdgesEnd();
for (; q_cur_edge != q_end_edge; ++q_cur_edge) {
output_edges_.insert(Node::EdgeEnd{q_cur_edge->GetNode(), src_idx, q_cur_edge->GetDstArgIndex()});
}
} else {
// non-Q node, or Q node that isn't in the QDQ node group (unexpected but may be possible). add as-is.
output_edges_.insert(*cur_edge);
}
}
NodeUnit::NodeUnit(std::vector<const Node*> dq_nodes, const Node& target_node,
std::vector<const Node*> q_nodes, Type type,
std::vector<NodeUnitIODef> inputs, std::vector<NodeUnitIODef> outputs,
size_t input_edge_count, Node::EdgeSet output_edges)
: dq_nodes_(std::move(dq_nodes)),
target_node_(target_node),
q_nodes_(std::move(q_nodes)),
type_(type),
inputs_(std::move(inputs)),
outputs_(std::move(outputs)),
input_edge_count_(input_edge_count),
output_edges_(std::move(output_edges)) {
}

const std::string& NodeUnit::Domain() const noexcept { return target_node_.Domain(); }
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/framework/node_unit.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,10 @@ class NodeUnit {
public:
explicit NodeUnit(const Node& node);
explicit NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group);
explicit NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group,
const Node& output_activation_node);
NodeUnit(std::vector<const Node*> dq_nodes, const Node& target_node,
std::vector<const Node*> q_nodes, Type type,
std::vector<NodeUnitIODef> inputs, std::vector<NodeUnitIODef> outputs,
size_t input_edge_count, Node::EdgeSet output_edges);

Type UnitType() const noexcept { return type_; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,19 +378,60 @@ static bool IsValidQDQConv(gsl::span<const NodeUnit*> dq_node_units,
Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper,
gsl::span<const NodeUnit*> dq_node_units,
const NodeUnit* conv_node_unit,
const NodeUnit* activation_node_unit,
const NodeUnit* q_node_unit,
const logging::Logger& logger,
bool validate) {
QDQ::NodeGroup custom_node_group;
custom_node_group.dq_nodes.reserve(dq_node_units.size());
custom_node_group.q_nodes = std::vector<NodeIndex>{q_node_unit->Index()};
custom_node_group.target_node = conv_node_unit->Index();
auto get_node_idx = [](const NodeUnit* n) { return n->Index(); };
std::transform(dq_node_units.begin(), dq_node_units.end(), std::back_inserter(custom_node_group.dq_nodes),
get_node_idx);

NodeUnit custom_node_unit(qnn_model_wrapper.GetGraphViewer(), custom_node_group, activation_node_unit->GetNode());
std::vector<const Node*> dq_nodes;
dq_nodes.reserve(dq_node_units.size());
for (const NodeUnit* dq_node_unit : dq_node_units) {
dq_nodes.push_back(&dq_node_unit->GetNode());
}
std::vector<const Node*> q_nodes = {&q_node_unit->GetNode()};
const Node& target_node = conv_node_unit->GetNode();

// Populate NodeUnit inputs
std::vector<NodeUnitIODef> inputs;
inputs.reserve(dq_node_units.size());
for (const Node* dq_node : dq_nodes) {
const auto dq_inputs = dq_node->InputDefs();
const auto& dq_attrs = dq_node->GetAttributes();

std::optional<int64_t> axis;
if (auto entry = dq_attrs.find("axis"); entry != dq_attrs.end()) {
axis = entry->second.i();
}

// quantization scale and zp are always the input[1, 2]
NodeUnitIODef::QuantParam quant_param{*dq_inputs[1], dq_inputs.size() == 3 ? dq_inputs[2] : nullptr, axis};
inputs.push_back(NodeUnitIODef{*dq_inputs[0], quant_param});
}

// Populate NodeUnit outputs and output edges
std::vector<NodeUnitIODef> outputs;
Node::EdgeSet output_edges;
for (const Node* q_node : q_nodes) {
const auto q_inputs = q_node->InputDefs();
const auto& q_attrs = q_node->GetAttributes();
const auto q_outputs = q_node->OutputDefs();

std::optional<int64_t> axis;
if (auto entry = q_attrs.find("axis"); entry != q_attrs.end()) {
axis = entry->second.i();
}

// quantization scale and zp are always the input[1, 2]
NodeUnitIODef::QuantParam quant_param{*q_inputs[1], q_inputs.size() == 3 ? q_inputs[2] : nullptr, axis};
outputs.push_back(NodeUnitIODef{*q_outputs[0], quant_param});

auto q_cur_edge = q_node->OutputEdgesBegin();
auto q_end_edge = q_node->OutputEdgesEnd();
for (; q_cur_edge != q_end_edge; ++q_cur_edge) {
output_edges.insert(Node::EdgeEnd{q_cur_edge->GetNode(), 0, q_cur_edge->GetDstArgIndex()});
}
}

NodeUnit custom_node_unit(dq_nodes, target_node, q_nodes, NodeUnit::Type::QDQGroup,
inputs, outputs, dq_nodes.size(), output_edges);
const auto* conv_op_builder = qnn::GetOpBuilder(custom_node_unit.OpType());
if (conv_op_builder == nullptr) {
return Status::OK();
Expand Down Expand Up @@ -463,7 +504,6 @@ std::optional<QnnNodeGroup> TryConvActivationFusion(QnnModelWrapper& qnn_model_w
if (Status status = QnnConvActivationFusionAdd(tmp_model_wrapper,
dq_node_units,
&conv_node_unit,
activation_node_unit,
q_node_unit,
logger,
/*validate*/ true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ namespace qnn {
Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper,
gsl::span<const NodeUnit*> dq_node_units,
const NodeUnit* conv_node_unit,
const NodeUnit* activation_node_unit,
const NodeUnit* q_node_unit,
const logging::Logger& logger,
bool validate = false);
Expand Down
3 changes: 0 additions & 3 deletions onnxruntime/core/providers/qnn/builder/qnn_fusions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& lo
Status status = QnnConvActivationFusionAdd(qmw,
dq_node_units,
conv_node_unit,
activation_node_unit,
q_node_unit,
logger,
/*validate*/ true);
Expand Down Expand Up @@ -225,7 +224,6 @@ Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logg
const bool has_bias_dq = num_node_units == 6;
std::vector<const NodeUnit*> dq_node_units = {node_units_[0], node_units_[1]};
const NodeUnit* conv_node_unit = node_units_[num_node_units - 3];
const NodeUnit* activation_node_unit = node_units_[num_node_units - 2];
const NodeUnit* q_node_unit = node_units_[num_node_units - 1];

if (has_bias_dq) {
Expand All @@ -234,7 +232,6 @@ Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logg
return QnnConvActivationFusionAdd(qmw,
dq_node_units,
conv_node_unit,
activation_node_unit,
q_node_unit,
logger,
/*validate*/ false);
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/providers/qnn/qnn_basic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,7 @@ TEST_F(QnnHTPBackendTests, TestOD) {

#if 1
const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "od_current_tf2onnx.onnx";
//so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
#else
const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "unet.preprocessed.quant.onnx_ctx.onnx";
#endif
Expand Down

0 comments on commit fa79a2f

Please sign in to comment.