Skip to content

Commit

Permalink
Merge pull request fastmachinelearning#87 from iksnagreb/fix/remove_i…
Browse files Browse the repository at this point in the history
…dentity_ops

Fix RemoveIdentityOps not correctly handling ops following fork-nodes
  • Loading branch information
maltanar authored Sep 12, 2024
2 parents e02f701 + 2d09341 commit 88a679a
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 24 deletions.
14 changes: 12 additions & 2 deletions src/qonnx/core/modelwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,14 +429,24 @@ def is_fork_node(self, node):
"""Checks if the given node is a fork, that is, the node has multiple
direct successors"""
direct_successors = self.find_direct_successors(node)
is_fork = False if direct_successors is None else (len(direct_successors) > 1)
# if the node output is also wired to a top-level output, it is still
# a fork with only 1 direct successor
if node.output[0] in [x.name for x in self.graph.output]:
is_fork = False if direct_successors is None else (len(direct_successors) > 0)
else:
is_fork = False if direct_successors is None else (len(direct_successors) > 1)
return is_fork

def is_join_node(self, node):
"""Checks if the given node is a join, that is, the node has multiple
direct predecessors"""
direct_predecessors = self.find_direct_predecessors(node)
is_join = False if direct_predecessors is None else (len(direct_predecessors) > 1)
# if the node input is also wired to a top-level input, it is still
# a fork with only 1 direct predecessor
if node.input[0] in [x.name for x in self.graph.input]:
is_join = False if direct_predecessors is None else (len(direct_predecessors) > 0)
else:
is_join = False if direct_predecessors is None else (len(direct_predecessors) > 1)
return is_join

def get_all_tensor_names(self):
Expand Down
47 changes: 36 additions & 11 deletions src/qonnx/transformation/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


import numpy as np
import warnings

from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.base import Transformation
Expand Down Expand Up @@ -58,21 +57,43 @@ def apply(self, model: ModelWrapper):


def remove_node_and_rewire(model, node):
# Currently cannot remove and rewire join-nodes, probably not necessary to
# support this
if model.is_join_node(node):
# Log this as a warning, so the user is aware of this, there might be
# somthing wrong or some checks missing at the caller site
warnings.warn("Removing join-node operation is currently not supported")
# Exit the function here without doing anything
return
# We already know that node is not a join-node, thus to rewire, we only need
# to check the single producer
producer = model.find_producer(node.input[0])
if producer is not None:
# wire output tensor to
# output of producer node
# If there is a producer which is not a fork-node, rewiring is simple
if producer is not None and not model.is_fork_node(producer):
# Rewire by skipping the node, letting the producer directly feed the
# nodes output.
# TODO: Check whether this already covers fork-node identities?
producer.output[0] = node.output[0]
# If there is no producer or the producer forks, rewiring is a bit more
# complicated
else:
# node is first in graph
# Now it depends on the successor nodes to rewire their inputs
successors = model.find_direct_successors(node)
# Singular node detached from the rest of the graph?
assert successors is not None, "Whole graph is one node."
for succ in successors:
for i, s_inp in enumerate(succ.input):
# We need to rewire the input of each successor to not detach parts of
# the graph
for successor in successors:
# Find the inputs of the successor which are produced by the node to
# be removed
for i, s_inp in enumerate(successor.input):
# Note: This might happen multiple times?
if s_inp == node.output[0]:
# rewire successor's input directly to graph input
succ.input[i] = node.input[0]
# remove node
# Rewire successor's input directly to nodes input
# Note: Node may not be a join-node, but there is probably
# no such thing as join-node identity anyway
successor.input[i] = node.input[0]
# Remove node
model.graph.node.remove(node)


Expand Down Expand Up @@ -117,5 +138,9 @@ def apply(self, model):
remove_node_and_rewire(model, n)
graph_modified = True
break
elif n.op_type == "Identity":
remove_node_and_rewire(model, n)
graph_modified = True
break
model = model.transform(InferShapes())
return (model, graph_modified)
30 changes: 19 additions & 11 deletions tests/transformation/test_remove_identity_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,28 +51,34 @@ def insert_identity_op(model, op, as_first_node, approx):
val = np.asarray([zero_val], dtype=np.float32)
elif op in ["Mul", "Div"]:
val = np.asarray([one_val], dtype=np.float32)
elif op in ["Identity"]:
val = None
else:
return

graph = model.graph
if val is None:
inplist = ["inp" if as_first_node else "div_out"]
else:
model.set_initializer("value", val)
inplist = ["inp" if as_first_node else "div_out", "value"]
identity_node = helper.make_node(op, inplist, ["ident_out"])
if as_first_node:
identity_node = helper.make_node(op, ["inp", "value"], ["ident_out"])
graph.node.insert(0, identity_node)
graph.node[1].input[0] = "ident_out"
else:
identity_node = helper.make_node(op, ["div_out", "value"], ["ident_out"])
graph.node.insert(3, identity_node)
graph.node[-1].input[0] = "ident_out"
model.set_initializer("value", val)

return model


# identity operations to be inserted
@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div"])
@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div", "Identity"])
@pytest.mark.parametrize("approx", [False, True])
@pytest.mark.parametrize("as_first_node", [False, True])
def test_remove_identity_ops(op, as_first_node, approx):
@pytest.mark.parametrize("fork_before_id", [False, True])
def test_remove_identity_ops(op, as_first_node, approx, fork_before_id):
# set up onnx model
inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 4, 1, 1])
mul = helper.make_tensor_value_info("mul", TensorProto.FLOAT, [])
Expand Down Expand Up @@ -109,14 +115,16 @@ def test_remove_identity_ops(op, as_first_node, approx):
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
idict = {"inp": inp_values}
odict = oxe.execute_onnx(model, idict)
out_before = odict["outp"]
odict_before = oxe.execute_onnx(model, idict)
num_of_nodes_before = len(model.graph.node)

if fork_before_id and not as_first_node:
divout_vi = model.get_tensor_valueinfo("div_out")
model.graph.output.append(divout_vi)
model.graph.value_info.remove(divout_vi)
model = model.transform(RemoveIdentityOps())
num_of_nodes_after = len(model.graph.node)
assert num_of_nodes_before - 1 == num_of_nodes_after

odict = oxe.execute_onnx(model, idict)
out_after = odict["outp"]
assert np.isclose(out_before, out_after, atol=1e-3).all()
odict_after = oxe.execute_onnx(model, idict)
outputs_same = [np.isclose(odict_before[tname], odict_after[tname], atol=1e-3).all() for tname in odict_before.keys()]
assert all(outputs_same)

0 comments on commit 88a679a

Please sign in to comment.