From 8902694106de98c827e38e04ffbf3f0d8dfc9675 Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Wed, 15 Nov 2023 09:41:18 +0100 Subject: [PATCH 1/6] Fix RemoveIdentityOps not correctly handling ops following fork-nodes --- src/qonnx/transformation/remove.py | 45 ++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/src/qonnx/transformation/remove.py b/src/qonnx/transformation/remove.py index e745f0f0..2fc888cb 100644 --- a/src/qonnx/transformation/remove.py +++ b/src/qonnx/transformation/remove.py @@ -25,9 +25,8 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - import numpy as np +import warnings from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.base import Transformation @@ -58,21 +57,45 @@ def apply(self, model: ModelWrapper): def remove_node_and_rewire(model, node): + # Currently cannot remove and rewire join-nodes, probably not necessary to + # support this + if model.is_join_node(node): + # Log this as a warning, so the user is aware of this, there might be + # somthing wrong or some checks missing at the caller site + warnings.warn( + "Tried to remove join-node operation: Currently not supported" + ) + # Exit the function here without doing anything + return + # We already know that node is not a join-node, thus to rewire, we only need + # to check the single producer producer = model.find_producer(node.input[0]) - if producer is not None: - # wire output tensor to - # output of producer node + # If there is a producer which is not a fork-node, rewiring is simple + if producer is not None and not model.is_fork_node(producer): + # Rewire by skipping the node, letting the producer directly feed the + # nodes output. + # TODO: Check whether this already covers fork-node identities? producer.output[0] = node.output[0] + # If there is no producer or the producer forks, rewiring is a bit more + # complicated else: - # node is first in graph + # Now it depends on the successor nodes to rewire their inputs successors = model.find_direct_successors(node) + # Singular node detached from the rest of the graph? assert successors is not None, "Whole graph is one node." - for succ in successors: - for i, s_inp in enumerate(succ.input): + # We need to rewire the input of each successor to not detach parts of + # the graph + for successor in successors: + # Find the inputs of the successor which are produced by the node to + # be removed + for i, s_inp in enumerate(successor.input): + # Note: This might happen multiple times? if s_inp == node.output[0]: - # rewire successor's input directly to graph input - succ.input[i] = node.input[0] - # remove node + # Rewire successor's input directly to nodes input + # Note: Node may not be a join-node, but there is probably + # no such thing as join-node identity anyway + successor.input[i] = node.input[0] + # Remove node model.graph.node.remove(node) From c7b359062dee8b979bc22741885ac812da8fe7ce Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Wed, 15 Nov 2023 09:50:20 +0100 Subject: [PATCH 2/6] Change error message to address some linting issue --- src/qonnx/transformation/remove.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/qonnx/transformation/remove.py b/src/qonnx/transformation/remove.py index 2fc888cb..980e80c1 100644 --- a/src/qonnx/transformation/remove.py +++ b/src/qonnx/transformation/remove.py @@ -62,9 +62,7 @@ def remove_node_and_rewire(model, node): if model.is_join_node(node): # Log this as a warning, so the user is aware of this, there might be # somthing wrong or some checks missing at the caller site - warnings.warn( - "Tried to remove join-node operation: Currently not supported" - ) + warnings.warn("Removing join-node operation is currently not supported") # Exit the function here without doing anything return # We already know that node is not a join-node, thus to rewire, we only need From 8bad7e71806d6c611c68fe00ac6007b076b08b5f Mon Sep 17 00:00:00 2001 From: jvreca Date: Thu, 22 Aug 2024 17:06:23 +0200 Subject: [PATCH 3/6] Added Identity node to the removal list --- src/qonnx/transformation/remove.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/qonnx/transformation/remove.py b/src/qonnx/transformation/remove.py index 980e80c1..0f7f38f7 100644 --- a/src/qonnx/transformation/remove.py +++ b/src/qonnx/transformation/remove.py @@ -138,5 +138,9 @@ def apply(self, model): remove_node_and_rewire(model, n) graph_modified = True break + elif n.op_type == "Identity": + remove_node_and_rewire(model, n) + graph_modified = True + break model = model.transform(InferShapes()) return (model, graph_modified) From 71ee78062ebdb5ae58dfbcc644d97d07dff3beb1 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 12 Sep 2024 10:42:11 +0300 Subject: [PATCH 4/6] [Test] add Identity op case to test_remove_identity_ops --- tests/transformation/test_remove_identity_ops.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/transformation/test_remove_identity_ops.py b/tests/transformation/test_remove_identity_ops.py index ed34ffe6..d9e92c73 100644 --- a/tests/transformation/test_remove_identity_ops.py +++ b/tests/transformation/test_remove_identity_ops.py @@ -51,25 +51,30 @@ def insert_identity_op(model, op, as_first_node, approx): val = np.asarray([zero_val], dtype=np.float32) elif op in ["Mul", "Div"]: val = np.asarray([one_val], dtype=np.float32) + elif op in ["Identity"]: + val = None else: return graph = model.graph + if val is None: + inplist = ["inp" if as_first_node else "div_out"] + else: + model.set_initializer("value", val) + inplist = ["inp" if as_first_node else "div_out", "value"] + identity_node = helper.make_node(op, inplist, ["ident_out"]) if as_first_node: - identity_node = helper.make_node(op, ["inp", "value"], ["ident_out"]) graph.node.insert(0, identity_node) graph.node[1].input[0] = "ident_out" else: - identity_node = helper.make_node(op, ["div_out", "value"], ["ident_out"]) graph.node.insert(3, identity_node) graph.node[-1].input[0] = "ident_out" - model.set_initializer("value", val) return model # identity operations to be inserted -@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div"]) +@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div", "Identity"]) @pytest.mark.parametrize("approx", [False, True]) @pytest.mark.parametrize("as_first_node", [False, True]) def test_remove_identity_ops(op, as_first_node, approx): From 0a4d5c5315082582d3a646e9504fe129b4ff0fd6 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 12 Sep 2024 12:01:18 +0300 Subject: [PATCH 5/6] [ModelWrapper] add top-level checks for fork/join checks --- src/qonnx/core/modelwrapper.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py index b95c6a33..779bb8f2 100644 --- a/src/qonnx/core/modelwrapper.py +++ b/src/qonnx/core/modelwrapper.py @@ -429,14 +429,24 @@ def is_fork_node(self, node): """Checks if the given node is a fork, that is, the node has multiple direct successors""" direct_successors = self.find_direct_successors(node) - is_fork = False if direct_successors is None else (len(direct_successors) > 1) + # if the node output is also wired to a top-level output, it is still + # a fork with only 1 direct successor + if node.output[0] in [x.name for x in self.graph.output]: + is_fork = False if direct_successors is None else (len(direct_successors) > 0) + else: + is_fork = False if direct_successors is None else (len(direct_successors) > 1) return is_fork def is_join_node(self, node): """Checks if the given node is a join, that is, the node has multiple direct predecessors""" direct_predecessors = self.find_direct_predecessors(node) - is_join = False if direct_predecessors is None else (len(direct_predecessors) > 1) + # if the node input is also wired to a top-level input, it is still + # a fork with only 1 direct predecessor + if node.input[0] in [x.name for x in self.graph.input]: + is_join = False if direct_predecessors is None else (len(direct_predecessors) > 0) + else: + is_join = False if direct_predecessors is None else (len(direct_predecessors) > 1) return is_join def get_all_tensor_names(self): From 2d0934111ad24928aa3a613f7262b835a0d135c3 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu Date: Thu, 12 Sep 2024 12:01:54 +0300 Subject: [PATCH 6/6] [Test] add fork cases to RemoveIdentityOps test --- .../transformation/test_remove_identity_ops.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/transformation/test_remove_identity_ops.py b/tests/transformation/test_remove_identity_ops.py index d9e92c73..cfe01a82 100644 --- a/tests/transformation/test_remove_identity_ops.py +++ b/tests/transformation/test_remove_identity_ops.py @@ -77,7 +77,8 @@ def insert_identity_op(model, op, as_first_node, approx): @pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div", "Identity"]) @pytest.mark.parametrize("approx", [False, True]) @pytest.mark.parametrize("as_first_node", [False, True]) -def test_remove_identity_ops(op, as_first_node, approx): +@pytest.mark.parametrize("fork_before_id", [False, True]) +def test_remove_identity_ops(op, as_first_node, approx, fork_before_id): # set up onnx model inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 4, 1, 1]) mul = helper.make_tensor_value_info("mul", TensorProto.FLOAT, []) @@ -114,14 +115,16 @@ def test_remove_identity_ops(op, as_first_node, approx): model = model.transform(InferShapes()) model = model.transform(InferDataTypes()) idict = {"inp": inp_values} - odict = oxe.execute_onnx(model, idict) - out_before = odict["outp"] + odict_before = oxe.execute_onnx(model, idict) num_of_nodes_before = len(model.graph.node) - + if fork_before_id and not as_first_node: + divout_vi = model.get_tensor_valueinfo("div_out") + model.graph.output.append(divout_vi) + model.graph.value_info.remove(divout_vi) model = model.transform(RemoveIdentityOps()) num_of_nodes_after = len(model.graph.node) assert num_of_nodes_before - 1 == num_of_nodes_after - odict = oxe.execute_onnx(model, idict) - out_after = odict["outp"] - assert np.isclose(out_before, out_after, atol=1e-3).all() + odict_after = oxe.execute_onnx(model, idict) + outputs_same = [np.isclose(odict_before[tname], odict_after[tname], atol=1e-3).all() for tname in odict_before.keys()] + assert all(outputs_same)