diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py index b95c6a33..779bb8f2 100644 --- a/src/qonnx/core/modelwrapper.py +++ b/src/qonnx/core/modelwrapper.py @@ -429,14 +429,24 @@ def is_fork_node(self, node): """Checks if the given node is a fork, that is, the node has multiple direct successors""" direct_successors = self.find_direct_successors(node) - is_fork = False if direct_successors is None else (len(direct_successors) > 1) + # if the node output is also wired to a top-level output, it is still + # a fork with only 1 direct successor + if node.output[0] in [x.name for x in self.graph.output]: + is_fork = False if direct_successors is None else (len(direct_successors) > 0) + else: + is_fork = False if direct_successors is None else (len(direct_successors) > 1) return is_fork def is_join_node(self, node): """Checks if the given node is a join, that is, the node has multiple direct predecessors""" direct_predecessors = self.find_direct_predecessors(node) - is_join = False if direct_predecessors is None else (len(direct_predecessors) > 1) + # if the node input is also wired to a top-level input, it is still + # a join with only 1 direct predecessor + if node.input[0] in [x.name for x in self.graph.input]: + is_join = False if direct_predecessors is None else (len(direct_predecessors) > 0) + else: + is_join = False if direct_predecessors is None else (len(direct_predecessors) > 1) return is_join def get_all_tensor_names(self): diff --git a/src/qonnx/transformation/remove.py b/src/qonnx/transformation/remove.py index e745f0f0..0f7f38f7 100644 --- a/src/qonnx/transformation/remove.py +++ b/src/qonnx/transformation/remove.py @@ -25,9 +25,8 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- - import numpy as np +import warnings from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.base import Transformation @@ -58,21 +57,43 @@ def apply(self, model: ModelWrapper): def remove_node_and_rewire(model, node): + # Currently cannot remove and rewire join-nodes, probably not necessary to + # support this + if model.is_join_node(node): + # Log this as a warning, so the user is aware of this, there might be + # something wrong or some checks missing at the caller site + warnings.warn("Removing join-node operation is currently not supported") + # Exit the function here without doing anything + return + # We already know that node is not a join-node, thus to rewire, we only need + # to check the single producer producer = model.find_producer(node.input[0]) - if producer is not None: - # wire output tensor to - # output of producer node + # If there is a producer which is not a fork-node, rewiring is simple + if producer is not None and not model.is_fork_node(producer): + # Rewire by skipping the node, letting the producer directly feed the + # node's output. + # TODO: Check whether this already covers fork-node identities? producer.output[0] = node.output[0] + # If there is no producer or the producer forks, rewiring is a bit more + # complicated else: - # node is first in graph + # Now it depends on the successor nodes to rewire their inputs successors = model.find_direct_successors(node) + # Singular node detached from the rest of the graph? assert successors is not None, "Whole graph is one node." - for succ in successors: - for i, s_inp in enumerate(succ.input): + # We need to rewire the input of each successor to not detach parts of + # the graph + for successor in successors: + # Find the inputs of the successor which are produced by the node to + # be removed + for i, s_inp in enumerate(successor.input): + # Note: This might happen multiple times? 
if s_inp == node.output[0]: - # rewire successor's input directly to graph input - succ.input[i] = node.input[0] - # remove node + # Rewire successor's input directly to nodes input + # Note: Node may not be a join-node, but there is probably + # no such thing as join-node identity anyway + successor.input[i] = node.input[0] + # Remove node model.graph.node.remove(node) @@ -117,5 +138,9 @@ def apply(self, model): remove_node_and_rewire(model, n) graph_modified = True break + elif n.op_type == "Identity": + remove_node_and_rewire(model, n) + graph_modified = True + break model = model.transform(InferShapes()) return (model, graph_modified) diff --git a/tests/transformation/test_remove_identity_ops.py b/tests/transformation/test_remove_identity_ops.py index ed34ffe6..cfe01a82 100644 --- a/tests/transformation/test_remove_identity_ops.py +++ b/tests/transformation/test_remove_identity_ops.py @@ -51,28 +51,34 @@ def insert_identity_op(model, op, as_first_node, approx): val = np.asarray([zero_val], dtype=np.float32) elif op in ["Mul", "Div"]: val = np.asarray([one_val], dtype=np.float32) + elif op in ["Identity"]: + val = None else: return graph = model.graph + if val is None: + inplist = ["inp" if as_first_node else "div_out"] + else: + model.set_initializer("value", val) + inplist = ["inp" if as_first_node else "div_out", "value"] + identity_node = helper.make_node(op, inplist, ["ident_out"]) if as_first_node: - identity_node = helper.make_node(op, ["inp", "value"], ["ident_out"]) graph.node.insert(0, identity_node) graph.node[1].input[0] = "ident_out" else: - identity_node = helper.make_node(op, ["div_out", "value"], ["ident_out"]) graph.node.insert(3, identity_node) graph.node[-1].input[0] = "ident_out" - model.set_initializer("value", val) return model # identity operations to be inserted -@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div"]) +@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div", "Identity"]) @pytest.mark.parametrize("approx", [False, 
True]) @pytest.mark.parametrize("as_first_node", [False, True]) -def test_remove_identity_ops(op, as_first_node, approx): +@pytest.mark.parametrize("fork_before_id", [False, True]) +def test_remove_identity_ops(op, as_first_node, approx, fork_before_id): # set up onnx model inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 4, 1, 1]) mul = helper.make_tensor_value_info("mul", TensorProto.FLOAT, []) @@ -109,14 +115,16 @@ def test_remove_identity_ops(op, as_first_node, approx): model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) idict = {"inp": inp_values} - odict = oxe.execute_onnx(model, idict) - out_before = odict["outp"] + odict_before = oxe.execute_onnx(model, idict) num_of_nodes_before = len(model.graph.node) - + if fork_before_id and not as_first_node: + divout_vi = model.get_tensor_valueinfo("div_out") + model.graph.output.append(divout_vi) + model.graph.value_info.remove(divout_vi) model = model.transform(RemoveIdentityOps()) num_of_nodes_after = len(model.graph.node) assert num_of_nodes_before - 1 == num_of_nodes_after - odict = oxe.execute_onnx(model, idict) - out_after = odict["outp"] - assert np.isclose(out_before, out_after, atol=1e-3).all() + odict_after = oxe.execute_onnx(model, idict) + outputs_same = [np.isclose(odict_before[tname], odict_after[tname], atol=1e-3).all() for tname in odict_before.keys()] + assert all(outputs_same)