Skip to content

Commit

Permalink
[Model Optimizer] Add rule to remove Resize nodes with unity Resize s…
Browse files Browse the repository at this point in the history
…cale (s=1's)
  • Loading branch information
reesegrimsley committed Aug 15, 2024
1 parent e102ef2 commit ed2703a
Showing 1 changed file with 40 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import onnx_graphsurgeon as gs
import onnx
import numpy as np
from tidl_onnx_model_optimizer.src.common import find_out_layers, remove_node


def tidl_convert_resize_params_size_to_scale(graph: gs.Graph,
Expand Down Expand Up @@ -102,3 +103,42 @@ def tidl_convert_resize_params_size_to_scale(graph: gs.Graph,
# endif
# endif
# endfor


def tidl_remove_unity_resize(graph: gs.Graph,
onnx_graph: onnx.GraphProto):
'''
Some models have an effectively null resize node that scales by a factor of 1 in all dimensions
Such a node is often an export artifact -- a layer added by a model format converter
This is node effectively unity, but it will be processed nonetheless. It should therefore be removed
'''

tensors = graph.tensors()
nodes_to_remove = []
for node in graph.nodes:

if node.op == "Resize":
inputs = node.inputs
if len(inputs) >= 3:
X, roi, scales = inputs[0:3]
else:
continue
Y = node.outputs[0]
attrs = node.attrs

if X.shape == Y.shape and all(map(lambda x: x==1, scales.values)):
#ensure it's not using ROI, which is only with crop-and-resize mode
if node.attrs['coordinate_transformation_mode'] == 'tf_crop_and_resize':
logging.warning("Detected Resize node as using ROI... skipping")
continue

logging.debug("Removing unity Resize node %s" % node.name)

out_nodes = find_out_layers(node)

for o_node in out_nodes:
for i, net in enumerate(o_node.inputs):
if net == Y:
o_node.inputs[i] = X

#node will be removed by cleanup since it has only unused outputs

0 comments on commit ed2703a

Please sign in to comment.