Skip to content

Commit

Permalink
[QNN Quantization] Ensure fused nodes have names (#19650)
Browse files Browse the repository at this point in the history
### Description
- Updates the `qnn_preprocess_model()` method to set a name for any new
nodes added to the graph (due to fusion).
- Updates the `qnn_preprocess_model()` method to set a name for any
unnamed nodes that previously existed in the original graph.
- Adds unit tests for fusions (previously missing)
  - Checks that fused node names exist and are unique
  - Checks that fused graph is equivalent to original graph


### Motivation and Context
Nodes are not strictly required to have names. However, a
planned/upcoming feature to support mixed-precision (integer) quantized
models needs nodes to have names.
  • Loading branch information
adrianlizarraga authored Feb 27, 2024
1 parent 1e69b61 commit 4838cb6
Show file tree
Hide file tree
Showing 7 changed files with 465 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ def fuse(

self.nodes_to_remove.extend(subgraph_nodes)
fused_node = onnx.helper.make_node(
self.fused_op_type, inputs=[subgraph_input], outputs=[subgraph_output], p=2, axis=-1
self.fused_op_type,
name=self.create_unique_node_name(),
inputs=[subgraph_input],
outputs=[subgraph_output],
p=2,
axis=-1,
)
self.nodes_to_add.append(fused_node)
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm:
if fusion_layernorm.apply():
modified = True

# Make sure all nodes have a name.
unnamed_node_prefix = "qnn_preproc_node_"
available_suffix = onnx_model.get_largest_node_name_suffix(unnamed_node_prefix) + 1
for node in onnx_model.model.graph.node:
if node.op_type != "Constant" and not node.name:
new_node_name = f"{unnamed_node_prefix}{available_suffix!s}"
available_suffix += 1
node.name = new_node_name
modified = True
logging.warning(f"Node of type {node.op_type} does not have a name. Renamed to {new_node_name}.")

if modified:
onnx_model.topological_sort()
onnx.save_model(model, model_output)
Expand Down
15 changes: 15 additions & 0 deletions onnxruntime/python/tools/quantization/fusions/fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def __init__(self, model: ONNXModel, fused_op_type: str, search_op_type: str):
self.nodes_to_remove: list = []
self.nodes_to_add: list = []

self._new_node_name_prefix = self.fused_op_type + "_fused_" + self.search_op_type + "_"
self._new_node_name_suffix = None # int|None used to create unique node names for the fused ops.

def fuse(
self,
node: onnx.NodeProto,
Expand Down Expand Up @@ -57,6 +60,18 @@ def apply(self) -> bool:

return graph_updated

def create_unique_node_name(self):
prefix = self._new_node_name_prefix

if self._new_node_name_suffix is None:
largest_suffix: int = self.model.get_largest_node_name_suffix(prefix)
self._new_node_name_suffix = largest_suffix + 1

new_name = f"{prefix}{self._new_node_name_suffix!s}"
self._new_node_name_suffix += 1

return new_name

@staticmethod
def is_safe_to_fuse_nodes(
nodes_to_remove: list[onnx.NodeProto],
Expand Down
25 changes: 14 additions & 11 deletions onnxruntime/python/tools/quantization/fusions/fusion_gelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def fuse_1(
return False

self.nodes_to_remove.extend(subgraph_nodes)
fused_node = onnx.helper.make_node("Gelu", inputs=[subgraph_input], outputs=[subgraph_output])
fused_node = onnx.helper.make_node(
"Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[subgraph_output]
)
fused_node.domain = "com.microsoft"
self.nodes_to_add.append(fused_node)
return True
Expand Down Expand Up @@ -173,11 +175,9 @@ def fuse_2(
if not self.has_constant_input(sqrt_node, 2.0):
return False

root_node = self.model.get_parent(div, 0, output_name_to_node)
if root_node is None:
return False
subgraph_input = div.input[0]

if root_node.output[0] not in mul.input:
if subgraph_input not in mul.input:
return False

subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul]
Expand All @@ -188,7 +188,9 @@ def fuse_2(
return False

self.nodes_to_remove.extend(subgraph_nodes)
fused_node = onnx.helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[mul.output[0]])
fused_node = onnx.helper.make_node(
"Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[mul.output[0]]
)
fused_node.domain = "com.microsoft"
self.nodes_to_add.append(fused_node)
return True
Expand Down Expand Up @@ -239,9 +241,8 @@ def fuse_3(
if i < 0:
return False

root_node = self.model.get_parent(first_mul, 0 if i == 1 else 1, output_name_to_node)
if root_node is None:
return False
root_input_index = 1 - i
subgraph_input = first_mul.input[root_input_index]

if mul_half.output[0] not in input_name_to_nodes:
return False
Expand All @@ -250,7 +251,7 @@ def fuse_3(
return False
last_mul = children[0]

if not (last_mul.input[0] == root_node.output[0] or last_mul.input[1] == root_node.output[0]):
if not (last_mul.input[0] == subgraph_input or last_mul.input[1] == subgraph_input):
return False

subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul]
Expand All @@ -263,7 +264,9 @@ def fuse_3(
return False

self.nodes_to_remove.extend(subgraph_nodes)
fused_node = onnx.helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[last_mul.output[0]])
fused_node = onnx.helper.make_node(
"Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[last_mul.output[0]]
)
fused_node.domain = "com.microsoft"
self.nodes_to_add.append(fused_node)
return True
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def fuse(

normalize_node = onnx.helper.make_node(
"LayerNormalization",
name=self.create_unique_node_name(),
inputs=[reduce_mean_node.input[0], weight_input, bias_input],
outputs=[last_add_node.output[0]],
)
Expand Down
17 changes: 17 additions & 0 deletions onnxruntime/python/tools/quantization/onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,23 @@ def find_node_by_name(self, node_name, new_nodes_list, graph):
node = find_by_name(node_name, graph_nodes_list)
return node

def get_largest_node_name_suffix(self, node_name_prefix):
"""
Gets the largest node name (int) suffix for all node names that begin with `node_name_prefix`.
Example: for nodes my_prefix_0 and my_prefix_3, this method returns 3.
"""
suffix = -1

for node in self.model.graph.node:
if node.name and node.name.startswith(node_name_prefix):
try:
index = int(node.name[len(node_name_prefix) :])
suffix = max(index, suffix)
except ValueError:
continue

return suffix

def find_nodes_by_initializer(self, graph, initializer):
"""
Find all nodes with given initializer as an input.
Expand Down
Loading

0 comments on commit 4838cb6

Please sign in to comment.