Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[QNN Quantization] Ensure fused nodes have names #19650

Merged
merged 9 commits into from
Feb 27, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ def fuse(

self.nodes_to_remove.extend(subgraph_nodes)
fused_node = onnx.helper.make_node(
self.fused_op_type, inputs=[subgraph_input], outputs=[subgraph_output], p=2, axis=-1
self.fused_op_type,
name=self.create_unique_node_name(),
inputs=[subgraph_input],
outputs=[subgraph_output],
p=2,
axis=-1,
)
self.nodes_to_add.append(fused_node)
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm:
if fusion_layernorm.apply():
modified = True

# Make sure all nodes have a name.
unnamed_node_prefix = "qnn_preproc_node_"
available_suffix = onnx_model.get_largest_node_name_suffix(unnamed_node_prefix) + 1
for node in onnx_model.model.graph.node:
if node.op_type != "Constant" and not node.name:
new_node_name = f"{unnamed_node_prefix}{available_suffix!s}"
available_suffix += 1
node.name = new_node_name
modified = True
logging.warning(f"Node of type {node.op_type} does not have a name. Renamed to {new_node_name}.")

if modified:
onnx_model.topological_sort()
onnx.save_model(model, model_output)
Expand Down
15 changes: 15 additions & 0 deletions onnxruntime/python/tools/quantization/fusions/fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def __init__(self, model: ONNXModel, fused_op_type: str, search_op_type: str):
self.nodes_to_remove: list = []
self.nodes_to_add: list = []

self._new_node_name_prefix = self.fused_op_type + "_fused_" + self.search_op_type + "_"
self._new_node_name_suffix = None # int|None used to create unique node names for the fused ops.

def fuse(
self,
node: onnx.NodeProto,
Expand Down Expand Up @@ -57,6 +60,18 @@ def apply(self) -> bool:

return graph_updated

def create_unique_node_name(self):
    """
    Returns a new node name built from this fusion's name prefix followed by an integer counter.

    The counter is initialized lazily: on the first call it is set to one more than the value
    returned by `self.model.get_largest_node_name_suffix(prefix)`. Every call then consumes the
    current counter value and advances it by one, so successive calls return distinct names.
    """
    name_prefix = self._new_node_name_prefix

    # First use: ask the model where numbering for this prefix should start.
    if self._new_node_name_suffix is None:
        self._new_node_name_suffix = self.model.get_largest_node_name_suffix(name_prefix) + 1

    current_suffix = self._new_node_name_suffix
    self._new_node_name_suffix = current_suffix + 1

    return f"{name_prefix}{current_suffix!s}"

@staticmethod
def is_safe_to_fuse_nodes(
nodes_to_remove: list[onnx.NodeProto],
Expand Down
25 changes: 14 additions & 11 deletions onnxruntime/python/tools/quantization/fusions/fusion_gelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def fuse_1(
return False

self.nodes_to_remove.extend(subgraph_nodes)
fused_node = onnx.helper.make_node("Gelu", inputs=[subgraph_input], outputs=[subgraph_output])
fused_node = onnx.helper.make_node(
"Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[subgraph_output]
)
fused_node.domain = "com.microsoft"
self.nodes_to_add.append(fused_node)
return True
Expand Down Expand Up @@ -173,11 +175,9 @@ def fuse_2(
if not self.has_constant_input(sqrt_node, 2.0):
return False

root_node = self.model.get_parent(div, 0, output_name_to_node)
if root_node is None:
return False
adrianlizarraga marked this conversation as resolved.
Show resolved Hide resolved
subgraph_input = div.input[0]

if root_node.output[0] not in mul.input:
if subgraph_input not in mul.input:
return False

subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul]
Expand All @@ -188,7 +188,9 @@ def fuse_2(
return False

self.nodes_to_remove.extend(subgraph_nodes)
fused_node = onnx.helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[mul.output[0]])
fused_node = onnx.helper.make_node(
"Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[mul.output[0]]
)
fused_node.domain = "com.microsoft"
self.nodes_to_add.append(fused_node)
return True
Expand Down Expand Up @@ -239,9 +241,8 @@ def fuse_3(
if i < 0:
return False

root_node = self.model.get_parent(first_mul, 0 if i == 1 else 1, output_name_to_node)
if root_node is None:
return False
root_input_index = 1 - i
subgraph_input = first_mul.input[root_input_index]

if mul_half.output[0] not in input_name_to_nodes:
return False
Expand All @@ -250,7 +251,7 @@ def fuse_3(
return False
last_mul = children[0]

if not (last_mul.input[0] == root_node.output[0] or last_mul.input[1] == root_node.output[0]):
if not (last_mul.input[0] == subgraph_input or last_mul.input[1] == subgraph_input):
return False

subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul]
Expand All @@ -263,7 +264,9 @@ def fuse_3(
return False

self.nodes_to_remove.extend(subgraph_nodes)
fused_node = onnx.helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[last_mul.output[0]])
fused_node = onnx.helper.make_node(
"Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[last_mul.output[0]]
)
fused_node.domain = "com.microsoft"
self.nodes_to_add.append(fused_node)
return True
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def fuse(

normalize_node = onnx.helper.make_node(
"LayerNormalization",
name=self.create_unique_node_name(),
inputs=[reduce_mean_node.input[0], weight_input, bias_input],
outputs=[last_add_node.output[0]],
)
Expand Down
17 changes: 17 additions & 0 deletions onnxruntime/python/tools/quantization/onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,23 @@ def find_node_by_name(self, node_name, new_nodes_list, graph):
node = find_by_name(node_name, graph_nodes_list)
return node

def get_largest_node_name_suffix(self, node_name_prefix):
    """
    Returns the largest integer suffix among the names of nodes in `self.model.graph.node`
    that start with `node_name_prefix`. Returns -1 if no such name has an integer suffix.

    Example: for nodes my_prefix_0 and my_prefix_3, this method returns 3.
    """
    largest = -1

    for graph_node in self.model.graph.node:
        name = graph_node.name
        if not name or not name.startswith(node_name_prefix):
            continue

        remainder = name[len(node_name_prefix) :]
        try:
            candidate = int(remainder)
        except ValueError:
            # The text after the prefix is not an integer; this name does not count.
            continue

        if candidate > largest:
            largest = candidate

    return largest

def find_nodes_by_initializer(self, graph, initializer):
"""
Find all nodes with given initializer as an input.
Expand Down
Loading
Loading