From ed2703a005830c789c574d88eeb84f57e18ef4d3 Mon Sep 17 00:00:00 2001 From: Reese Grimsley Date: Thu, 15 Aug 2024 11:52:09 -0500 Subject: [PATCH] [Model Optimizer] Add rule to remove Resize nodes with unity Resize scale (s=1's) --- .../tidl_onnx_model_optimizer/src/resize.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/scripts/osrt_model_tools/onnx_tools/tidl-onnx-model-optimizer/tidl_onnx_model_optimizer/src/resize.py b/scripts/osrt_model_tools/onnx_tools/tidl-onnx-model-optimizer/tidl_onnx_model_optimizer/src/resize.py index c3d1470..62aae37 100644 --- a/scripts/osrt_model_tools/onnx_tools/tidl-onnx-model-optimizer/tidl_onnx_model_optimizer/src/resize.py +++ b/scripts/osrt_model_tools/onnx_tools/tidl-onnx-model-optimizer/tidl_onnx_model_optimizer/src/resize.py @@ -63,6 +63,7 @@ import onnx_graphsurgeon as gs import onnx import numpy as np +from tidl_onnx_model_optimizer.src.common import find_out_layers, remove_node def tidl_convert_resize_params_size_to_scale(graph: gs.Graph, @@ -102,3 +103,42 @@ def tidl_convert_resize_params_size_to_scale(graph: gs.Graph, # endif # endif # endfor + + +def tidl_remove_unity_resize(graph: gs.Graph, + onnx_graph: onnx.GraphProto): + ''' + Some models have an effectively null resize node that scales by a factor of 1 in all dimensions + Such a node is often an export artifact -- a layer added by a model format converter + This is node effectively unity, but it will be processed nonetheless. It should therefore be removed + ''' + + tensors = graph.tensors() + nodes_to_remove = [] + for node in graph.nodes: + + if node.op == "Resize": + inputs = node.inputs + if len(inputs) >= 3: + X, roi, scales = inputs[0:3] + else: + continue + Y = node.outputs[0] + attrs = node.attrs + + if X.shape == Y.shape and all(map(lambda x: x==1, scales.values)): + #ensure it's not using ROI, which is only with crop-and-resize mode + if node.attrs['coordinate_transformation_mode'] == 'tf_crop_and_resize': + logging.warning("Detected Resize node as using ROI... skipping") + continue + + logging.debug("Removing unity Resize node %s" % node.name) + + out_nodes = find_out_layers(node) + + for o_node in out_nodes: + for i, net in enumerate(o_node.inputs): + if net == Y: + o_node.inputs[i] = X + + #node will be removed by cleanup since it has only unused outputs