From b66c6d1b3dcf9b5a94755e912d7efb4fc4d6ddd6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 8 Apr 2024 18:51:46 -0700 Subject: [PATCH] [IR] Mutable IR implementation (#1344) Implement mutation methods on Graph and Function and a DoublyLinkedList to support safe mutation during iteration. Nodes in different graphs can be moved to other graphs safely during iteration. ### Methods implemented ``` - __getitem__ - __len__ - __iter__ - __reversed__ - append - extend - remove - insert_after - insert_before - sort (declared, not implemented) ``` The mutation methods are inspired by the pytorch FX graph methods. ### Safe Iterators The behavior of the iterators is: - If new elements are inserted after the current node, the iterator will iterate over them as well. - If new elements are inserted before the current node, they will not be iterated over in this iteration. - If the current node is lifted and inserted in a different location, iteration will start from the "next" node at the _original_ location. A node cannot be added to a graph if it belongs to another graph. It needs to be removed first. The user is responsible for removing nodes from other graphs before adding the same node to a new graph. ### Linked list implementation The doubly linked list implementation is inspired by the pytorch FX graph: They are both doubly linked lists with a dummy root node. Whereas in FX `graph` contains the root node and `Node`s are the links; we create the `DoublyLinkedOrderedSet` class to decouple the list implementation form the graph. Additionally, we created the `LinkedBox` class as a data struct to hold the prev/next pointers for the Nodes to allow nodes to move across graphs. This is in contrast to FX restriction of all nodes must belong to the same graph in their lifecycle. By allowing the nodes to move to different graphs, we are able to flatten/unflatten graphs without copying. DoublyLinkedOrderedSet is unit tested. ### Next steps 1. Naming authority 2. Graph input/output/initializers mutation --- onnxscript/ir/__init__.py | 4 + onnxscript/ir/_core.py | 294 ++++++++++++++++------ onnxscript/ir/_core_test.py | 4 + onnxscript/ir/_display.py | 4 + onnxscript/ir/_display_test.py | 4 + onnxscript/ir/_enums.py | 4 + onnxscript/ir/_graph_comparison.py | 4 + onnxscript/ir/_invariants.py | 4 + onnxscript/ir/_linked_list.py | 277 +++++++++++++++++++++ onnxscript/ir/_linked_list_test.py | 380 +++++++++++++++++++++++++++++ onnxscript/ir/_metadata.py | 4 + onnxscript/ir/_protocols.py | 86 +++++-- onnxscript/ir/convenience.py | 4 + onnxscript/ir/serde.py | 4 + 14 files changed, 993 insertions(+), 84 deletions(-) create mode 100644 onnxscript/ir/_linked_list.py create mode 100644 onnxscript/ir/_linked_list_test.py diff --git a/onnxscript/ir/__init__.py b/onnxscript/ir/__init__.py index 8f49989aa..fb090d3b9 100644 --- a/onnxscript/ir/__init__.py +++ b/onnxscript/ir/__init__.py @@ -1,3 +1,7 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- """In-memory intermediate representation for ONNX graphs.""" __all__ = [ diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 95217cd1d..7d2b7977c 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -1,3 +1,7 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- """data structures for the intermediate representation.""" # NOTES for developers: @@ -33,6 +37,7 @@ from onnxscript.ir import ( _display, _enums, + _linked_list, _metadata, _protocols, ) @@ -548,6 +553,10 @@ def __init__( if input_value is not None: input_value.add_user(self, i) + # Add the node to the graph if graph is specified + if self._graph is not None: + self._graph.append(self) + def __str__(self) -> str: node_type_text = f"{self._domain}::{self._op_type}" + f":{self._overload}" * ( self._overload != "" @@ -633,18 +642,56 @@ def inputs(self, _: Any) -> None: "Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead." ) - def replace_input_with(self, index: int, new_input: Value | None) -> None: + def replace_input_with(self, index: int, value: Value | None) -> None: """Replace an input with a new value.""" if index < 0 or index >= len(self.inputs): raise ValueError(f"Index out of range: {index}") old_input = self.inputs[index] self._inputs = tuple( - new_input if i == index else old_input for i, old_input in enumerate(self.inputs) + value if i == index else old_input for i, old_input in enumerate(self.inputs) ) if old_input is not None: old_input.remove_user(self, index) - if new_input is not None: - new_input.add_user(self, index) + if value is not None: + value.add_user(self, index) + + def prepend(self, /, nodes: Node | Iterable[Node]) -> None: + """Insert a node before this node in the list of nodes in the graph. + + It is the same as calling ``graph.insert_before(self, nodes)``. + + Example:: + + Before: previous_node -> self + previous_node' -> node -> next_node' + After: previous_node -> node -> self + previous_node' -> next_node' + + Args: + nodes: A node or a sequence of nodes to put before this node. + """ + if self._graph is None: + raise ValueError("The node to prepend to does not belong to any graph.") + self._graph.insert_before(self, nodes) + + def append(self, /, nodes: Node | Iterable[Node]) -> None: + """Insert a node after this node in the list of nodes in the graph. + + It is the same as calling ``graph.insert_after(self, nodes)``. + + Example:: + + Before: previous_node -> self + previous_node' -> node -> next_node' + After: previous_node -> self -> node + previous_node' -> next_node' + + Args: + nodes: A node or a sequence of nodes to put after this node. + """ + if self._graph is None: + raise ValueError("The node to append to does not belong to any graph.") + self._graph.insert_after(self, nodes) @property def outputs(self) -> Sequence[Value]: @@ -985,6 +1032,8 @@ def __init__( name: str | None = None, ): self.name = name + + # Private fields that are not to be accessed by any other classes self._inputs = tuple(inputs) self._outputs = tuple(outputs) for initializer in initializers: @@ -995,11 +1044,9 @@ def __init__( self._opset_imports = opset_imports or {} self._metadata: _metadata.MetadataStore | None = None self._metadata_props: dict[str, str] | None = None - - # Assign this graph as the owning_graph of all nodes - self._nodes = list(nodes) - for node in self._nodes: - node.graph = self + self._nodes: _linked_list.DoublyLinkedSet[Node] = _linked_list.DoublyLinkedSet() + # Call self.extend not self._nodes.extend so the graph reference is added to the nodes + self.extend(nodes) @property def inputs(self) -> tuple[Value, ...]: @@ -1027,11 +1074,104 @@ def opset_imports(self) -> dict[str, int]: @property def nodes(self) -> Sequence[Node]: - return self._nodes + return tuple(self._nodes) - def topologically_sorted_nodes(self) -> Sequence[Node]: + def __getitem__(self, index: int) -> Node: + return self._nodes[index] + + def __len__(self) -> int: + return len(self._nodes) + + def __iter__(self) -> Iterator[Node]: + return iter(self._nodes) + + def __reversed__(self) -> Iterator[Node]: + return reversed(self._nodes) + + def _set_node_graph_to_self(self, node: Node) -> Node: + """Set the graph reference for the node.""" + if node.graph is not None and node.graph is not self: + raise ValueError( + f"The node {node} belongs to another graph. Please remove it first with Graph.remove()." + ) + node._graph = self # pylint: disable=protected-access + return node + + # Mutation methods + def append(self, node: Node, /) -> None: + """Append a node to the graph in O(1) time. + + Args: + node: The node to append. + + Raises: + ValueError: If the node belongs to another graph. + """ + self._set_node_graph_to_self(node) + self._nodes.append(node) + + def extend(self, nodes: Iterable[Node], /) -> None: + """Extend the graph with the given nodes in O(#new_nodes) time. + + Args: + nodes: The nodes to extend the graph with. + + Raises: + ValueError: If any node belongs to another graph. + """ + nodes = [self._set_node_graph_to_self(node) for node in nodes] + self._nodes.extend(nodes) + + def remove(self, node: Node, /) -> None: + """Remove a node from the graph in O(1) time. + + Args: + node: The node to remove. + + Raises: + ValueError: If the node does not belong to this graph. + """ + if node.graph is not self: + raise ValueError(f"The node {node} does not belong to this graph.") + node._graph = None # pylint: disable=protected-access + self._nodes.remove(node) + + def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None: + """Insert new nodes after the given node in O(#new_nodes) time. + + Args: + node: The node to insert after. + new_nodes: The new nodes to insert. + + Raises: + ValueError: If any node belongs to another graph. + """ + if isinstance(new_nodes, Node): + new_nodes = (new_nodes,) + new_nodes = [self._set_node_graph_to_self(node) for node in new_nodes] + self._nodes.insert_after(node, new_nodes) + + def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None: + """Insert new nodes before the given node in O(#new_nodes) time. + + Args: + node: The node to insert before. + new_nodes: The new nodes to insert. + + Raises: + ValueError: If any node belongs to another graph. + """ + if isinstance(new_nodes, Node): + new_nodes = (new_nodes,) + new_nodes = [self._set_node_graph_to_self(node) for node in new_nodes] + self._nodes.insert_before(node, new_nodes) + + def sort(self) -> None: + """Topologically sort the nodes in the graph.""" raise NotImplementedError("Not implemented yet") + # End of mutation methods + @property def meta(self) -> _metadata.MetadataStore: """The metadata store for intermediate analysis. @@ -1049,15 +1189,6 @@ def metadata_props(self) -> dict[str, str]: self._metadata_props = {} return self._metadata_props - def __getitem__(self, index: int) -> Node: - return self._nodes[index] - - def __len__(self) -> int: - return len(self._nodes) - - def __iter__(self) -> Iterator[Node]: - return iter(self._nodes) - def __str__(self) -> str: # TODO(justinchuby): Show docstrings and metadata inputs_text = "\n" + ",\n".join(str(x) for x in self.inputs) @@ -1155,7 +1286,7 @@ def __init__( doc_string: str | None = None, functions: Sequence[Function] = (), ) -> None: - self.graph: Graph = graph + self.graph: Graph = graph # type: ignore[assignment] self.ir_version = ir_version self.producer_name = producer_name self.producer_version = producer_version @@ -1246,53 +1377,6 @@ def __init__( self._metadata: _metadata.MetadataStore | None = None self._metadata_props: dict[str, str] | None = None - def __str__(self) -> str: - full_name = f"{self.domain}::{self.name}" + f":{self.overload}" * (self.overload != "") - inputs_text = ",\n".join(str(x) for x in self.inputs) - outputs_text = ",\n".join(str(x) for x in self.outputs) - attributes_text = ",\n".join( - attr.name + f": {attr.type}" + f"= {attr.value}" * (attr.value is None) - for attr in self.attributes.values() - ) - if attributes_text: - attributes_text = ( - "\nattributes={\n" + textwrap.indent(attributes_text, " " * 4) + "\n}" - ) - signature = f"""\ -< - opset_imports={self.opset_imports!r}, -> -def {full_name}( - inputs=( -{textwrap.indent(inputs_text, ' '*8)} - ),{textwrap.indent(attributes_text, ' '*4)} - outputs=( -{textwrap.indent(outputs_text, ' '*8)} - ), -)""" - node_count = len(self.nodes) - number_width = len(str(node_count)) - node_lines = [] - for i, node in enumerate(self.nodes): - node_name = node.name if node.name else f":anonymous_node:{id(node)}" - node_text = f"# {node_name}\n{node}" - indented_node_text = textwrap.indent(node_text, " " * (number_width + 4)) - # Remove the leading spaces - indented_node_text = indented_node_text.strip() - node_lines.append(f"{i:>{number_width}} | {indented_node_text}") - returns = ", ".join(str(x) for x in self.outputs) - body = ( - "{\n" - + textwrap.indent("\n".join(node_lines), " " * 4) - + textwrap.indent(f"\nreturn {returns}", " " * 4) - + "\n}" - ) - - return f"{signature} {body}" - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.domain!r}, {self.name!r}, {self.overload!r}, inputs={self.inputs!r}, attributes={self.attributes!r}), outputs={self.outputs!r})" - def identifier(self) -> _protocols.OperatorIdentifier: return self.domain, self.name, self.overload @@ -1364,6 +1448,80 @@ def metadata_props(self) -> dict[str, str]: self._metadata_props = {} return self._metadata_props + # Mutation methods + def append(self, node: Node, /) -> None: + """Append a node to the function in O(1) time.""" + self._graph.append(node) + + def extend(self, nodes: Iterable[Node], /) -> None: + """Extend the function with the given nodes in O(#new_nodes) time.""" + self._graph.extend(nodes) + + def remove(self, node: Node, /) -> None: + """Remove a node from the function in O(1) time.""" + self._graph.remove(node) + + def insert_after(self, node: Node, new_nodes: Iterable[Node], /) -> None: + """Insert new nodes after the given node in O(#new_nodes) time.""" + self._graph.insert_after(node, new_nodes) + + def insert_before(self, node: Node, new_nodes: Iterable[Node], /) -> None: + """Insert new nodes before the given node in O(#new_nodes) time.""" + self._graph.insert_before(node, new_nodes) + + def sort(self) -> None: + """Topologically sort the nodes in the function.""" + self._graph.sort() + + # End of mutation methods + + def __str__(self) -> str: + full_name = f"{self.domain}::{self.name}" + f":{self.overload}" * (self.overload != "") + inputs_text = ",\n".join(str(x) for x in self.inputs) + outputs_text = ",\n".join(str(x) for x in self.outputs) + attributes_text = ",\n".join( + attr.name + f": {attr.type}" + f"= {attr.value}" * (attr.value is None) + for attr in self.attributes.values() + ) + if attributes_text: + attributes_text = ( + "\nattributes={\n" + textwrap.indent(attributes_text, " " * 4) + "\n}" + ) + signature = f"""\ +< + opset_imports={self.opset_imports!r}, +> +def {full_name}( + inputs=( +{textwrap.indent(inputs_text, ' '*8)} + ),{textwrap.indent(attributes_text, ' '*4)} + outputs=( +{textwrap.indent(outputs_text, ' '*8)} + ), +)""" + node_count = len(self.nodes) + number_width = len(str(node_count)) + node_lines = [] + for i, node in enumerate(self.nodes): + node_name = node.name if node.name else f":anonymous_node:{id(node)}" + node_text = f"# {node_name}\n{node}" + indented_node_text = textwrap.indent(node_text, " " * (number_width + 4)) + # Remove the leading spaces + indented_node_text = indented_node_text.strip() + node_lines.append(f"{i:>{number_width}} | {indented_node_text}") + returns = ", ".join(str(x) for x in self.outputs) + body = ( + "{\n" + + textwrap.indent("\n".join(node_lines), " " * 4) + + textwrap.indent(f"\nreturn {returns}", " " * 4) + + "\n}" + ) + + return f"{signature} {body}" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.domain!r}, {self.name!r}, {self.overload!r}, inputs={self.inputs!r}, attributes={self.attributes!r}), outputs={self.outputs!r})" + class RefAttr(_protocols.ReferenceAttributeProtocol, _display.PrettyPrintable): """Reference attribute.""" diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py index fd7391321..972e20f9b 100644 --- a/onnxscript/ir/_core_test.py +++ b/onnxscript/ir/_core_test.py @@ -1,3 +1,7 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- from __future__ import annotations import pathlib diff --git a/onnxscript/ir/_display.py b/onnxscript/ir/_display.py index 2269506d9..937af9299 100644 --- a/onnxscript/ir/_display.py +++ b/onnxscript/ir/_display.py @@ -1,3 +1,7 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- """Internal utilities for displaying the intermediate representation of a model. NOTE: All third-party imports should be scoped and imported only when used to avoid diff --git a/onnxscript/ir/_display_test.py b/onnxscript/ir/_display_test.py index f32595d52..b334237eb 100644 --- a/onnxscript/ir/_display_test.py +++ b/onnxscript/ir/_display_test.py @@ -1,3 +1,7 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- """Test display() methods in various classes.""" import contextlib diff --git a/onnxscript/ir/_enums.py b/onnxscript/ir/_enums.py index 96ae0b478..2c4c87391 100644 --- a/onnxscript/ir/_enums.py +++ b/onnxscript/ir/_enums.py @@ -1,3 +1,7 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- """ONNX IR enums that matches the ONNX spec.""" from __future__ import annotations diff --git a/onnxscript/ir/_graph_comparison.py b/onnxscript/ir/_graph_comparison.py index 239a5f228..788b4b4d5 100644 --- a/onnxscript/ir/_graph_comparison.py +++ b/onnxscript/ir/_graph_comparison.py @@ -1,3 +1,7 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- """Utilities for comparing IR graphs.""" from __future__ import annotations diff --git a/onnxscript/ir/_invariants.py b/onnxscript/ir/_invariants.py index 343fb79e6..8d009c3cc 100644 --- a/onnxscript/ir/_invariants.py +++ b/onnxscript/ir/_invariants.py @@ -1,3 +1,7 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- """Utilities to enforce invariants on the IR.""" from __future__ import annotations diff --git a/onnxscript/ir/_linked_list.py b/onnxscript/ir/_linked_list.py new file mode 100644 index 000000000..de38c25f4 --- /dev/null +++ b/onnxscript/ir/_linked_list.py @@ -0,0 +1,277 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Mutable list for nodes in a graph with safe mutation properties.""" + +from __future__ import annotations + +from typing import Generic, Iterable, Iterator, Sequence, TypeVar + +T = TypeVar("T") + + +class _LinkBox(Generic[T]): + """A link in a doubly linked list that has a reference to the actual object in the link. + + The :class:`_LinkBox` is a container for the actual object in the list. It is used to + maintain the links between the elements in the linked list. The actual object is stored in the + :attr:`value` attribute. + + By using a separate container for the actual object, we can safely remove the object from the + list without losing the links. This allows us to remove the object from the list during + iteration and place the object into a different list without breaking any chains. + + This is an internal class and should only be initialized by the :class:`DoublyLinkedSet`. + + Attributes: + prev: The previous box in the list. + next: The next box in the list. + erased: A flag to indicate if the box has been removed from the list. + owning_list: The :class:`DoublyLinkedSet` to which the box belongs. + value: The actual object in the list. + """ + + __slots__ = ("prev", "next", "value", "owning_list") + + def __init__(self, owner: DoublyLinkedSet[T], value: T | None) -> None: + """Create a new link box. + + Args: + owner: The linked list to which this box belongs. + value: The value to be stored in the link box. When the value is None, + the link box is considered erased (default). The root box of the list + should be created with a None value. + """ + self.prev: _LinkBox[T] = self + self.next: _LinkBox[T] = self + self.value: T | None = value + self.owning_list: DoublyLinkedSet[T] = owner + + @property + def erased(self) -> bool: + return self.value is None + + def erase(self) -> None: + """Remove the link from the list and detach the value from the box.""" + if self.value is None: + raise ValueError("_LinkBox is already erased") + # Update the links + prev, next_ = self.prev, self.next + prev.next, next_.prev = next_, prev + # Detach the value + self.value = None + + def __repr__(self) -> str: + return f"_LinkBox({self.value!r}, erased={self.erased}, prev={self.prev.value!r}, next={self.next.value!r})" + + +class DoublyLinkedSet(Generic[T], Sequence[T]): + """A doubly linked ordered set of nodes. + + The container can be viewed as a set as it does not allow duplicate values. The order of the + elements is maintained. One can typically treat it as a doubly linked list with list-like + methods implemented. + + Adding and removing elements from the set during iteration is safe. Moving elements + from one set to another is also safe. + + During the iteration: + - If new elements are inserted after the current node, the iterator will + iterate over them as well. + - If new elements are inserted before the current node, they will + not be iterated over in this iteration. + - If the current node is lifted and inserted in a different location, + iteration will start from the "next" node at the _original_ location. + + Time complexity: + Inserting and removing nodes from the set is O(1). Accessing nodes by index is O(n), + although accessing nodes at either end of the set is O(1). I.e. + ``linked_set[0]`` and ``linked_set[-1]`` are O(1). + + Values need to be hashable. ``None`` is not a valid value in the set. + """ + + __slots__ = ("_root", "_length", "_value_ids_to_boxes") + + def __init__(self, values: Iterable[T] | None = None) -> None: + # Using the root node simplifies the mutation implementation a lot + root_ = _LinkBox(self, None) + self._root: _LinkBox = root_ + self._length = 0 + self._value_ids_to_boxes: dict[int, _LinkBox] = {} + if values is not None: + self.extend(values) + + def __iter__(self) -> Iterator[T]: + """Iterate over the elements in the list. + + - If new elements are inserted after the current node, the iterator will + iterate over them as well. + - If new elements are inserted before the current node, they will + not be iterated over in this iteration. + - If the current node is lifted and inserted in a different location, + iteration will start from the "next" node at the _original_ location. + """ + box = self._root.next + while box is not self._root: + if box.owning_list is not self: + raise RuntimeError(f"Element {box!r} is not in the list") + if not box.erased: + assert box.value is not None + yield box.value + box = box.next + + def __reversed__(self) -> Iterator[T]: + """Iterate over the elements in the list in reverse order.""" + box = self._root.prev + while box is not self._root: + if not box.erased: + assert box.value is not None + yield box.value + box = box.prev + + def __len__(self) -> int: + assert self._length == len( + self._value_ids_to_boxes + ), "Bug in the implementation: length mismatch" + return self._length + + def __getitem__(self, index: int) -> T: + """Get the node at the given index. + + Complexity is O(n). + """ + if index >= self._length or index < -self._length: + raise IndexError( + f"Index out of range: {index} not in range [-{self._length}, {self._length})" + ) + if index < 0: + # Look up from the end of the list + iterator = reversed(self) + item = next(iterator) + for _ in range(-index - 1): + item = next(iterator) + else: + iterator = iter(self) # type: ignore[assignment] + item = next(iterator) + for _ in range(index): + item = next(iterator) + return item + + def _insert_one_after( + self, + box: _LinkBox[T], + new_value: T, + ) -> _LinkBox[T]: + """Insert a new value after the given box. + + All insertion methods should call this method to ensure that the list is updated correctly. + + Example:: + Before: A <-> B <-> C + ^v0 ^v1 ^v2 + Call: _insert_one_after(B, v3) + After: A <-> B <-> new_box <-> C + ^v0 ^v1 ^v3 ^v2 + + Args: + box: The box which the new value is to be inserted. + new_value: The new value to be inserted. + """ + if new_value is None: + raise TypeError(f"{self.__class__.__name__} does not support None values") + if box.value is new_value: + # Do nothing if the new value is the same as the old value + return box + if box.owning_list is not self: + raise ValueError(f"Value {box.value!r} is not in the list") + + if (new_value_id := id(new_value)) in self._value_ids_to_boxes: + # If the value is already in the list, remove it first + self.remove(new_value) + + # Create a new _LinkBox for the new value + new_box = _LinkBox(self, new_value) + # original_box <=> original_next + # becomes + # original_box <=> new_box <=> original_next + original_next = box.next + box.next = new_box + new_box.prev = box + new_box.next = original_next + original_next.prev = new_box + + # Be sure to update the length and mapping + self._length += 1 + self._value_ids_to_boxes[new_value_id] = new_box + + return new_box + + def _insert_many_after( + self, + box: _LinkBox[T], + new_values: Iterable[T], + ): + """Insert multiple new values after the given box.""" + insertion_point = box + for new_value in new_values: + insertion_point = self._insert_one_after(insertion_point, new_value) + + def remove(self, value: T) -> None: + """Remove a node from the list.""" + if (value_id := id(value)) not in self._value_ids_to_boxes: + raise ValueError(f"Value {value!r} is not in the list") + box = self._value_ids_to_boxes[value_id] + # Remove the link box and detach the value from the box + box.erase() + + # Be sure to update the length and mapping + self._length -= 1 + del self._value_ids_to_boxes[value_id] + + def append(self, value: T) -> None: + """Append a node to the list.""" + _ = self._insert_one_after(self._root.prev, value) + + def extend( + self, + values: Iterable[T], + ) -> None: + for value in values: + self.append(value) + + def insert_after( + self, + value: T, + new_values: Iterable[T], + ) -> None: + """Insert new nodes after the given node. + + Args: + value: The value after which the new values are to be inserted. + new_values: The new values to be inserted. + """ + if (value_id := id(value)) not in self._value_ids_to_boxes: + raise ValueError(f"Value {value!r} is not in the list") + insertion_point = self._value_ids_to_boxes[value_id] + return self._insert_many_after(insertion_point, new_values) + + def insert_before( + self, + value: T, + new_values: Iterable[T], + ) -> None: + """Insert new nodes before the given node. + + Args: + value: The value before which the new values are to be inserted. + new_values: The new values to be inserted. + """ + if (value_id := id(value)) not in self._value_ids_to_boxes: + raise ValueError(f"Value {value!r} is not in the list") + insertion_point = self._value_ids_to_boxes[value_id].prev + return self._insert_many_after(insertion_point, new_values) + + def __repr__(self) -> str: + return f"DoublyLinkedSet({list(self)})" diff --git a/onnxscript/ir/_linked_list_test.py b/onnxscript/ir/_linked_list_test.py new file mode 100644 index 000000000..a82b0e172 --- /dev/null +++ b/onnxscript/ir/_linked_list_test.py @@ -0,0 +1,380 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Unit tests for the _linked_list module.""" + +from __future__ import annotations + +import unittest + +import parameterized + +from onnxscript.ir import _linked_list + + +class _TestElement: + def __init__(self, value): + self.value = value + + def __repr__(self) -> str: + return f"_TestElement({self.value})" + + +class DoublyLinkedSetTest(unittest.TestCase): + def test_empty_list(self): + linked_list = _linked_list.DoublyLinkedSet() + self.assertEqual(len(linked_list), 0) + self.assertEqual(list(linked_list), []) + self.assertEqual(list(reversed(linked_list)), []) + with self.assertRaises(IndexError): + _ = linked_list[0] + with self.assertRaises(IndexError): + _ = linked_list[-1] + + def test_append_single_element(self): + linked_list = _linked_list.DoublyLinkedSet() + elem = _TestElement(0) + linked_list.append(elem) + + self.assertEqual(len(linked_list), 1) + self.assertEqual(linked_list[0], elem) + self.assertEqual(linked_list[-1], elem) + self.assertEqual(list(linked_list), [elem]) + self.assertEqual(list(reversed(linked_list)), [elem]) + with self.assertRaises(IndexError): + _ = linked_list[1] + with self.assertRaises(IndexError): + _ = linked_list[-2] + + def test_append_multiple_elements(self): + linked_list = _linked_list.DoublyLinkedSet() + elems = [_TestElement(i) for i in range(3)] + for elem in elems: + linked_list.append(elem) + + self.assertEqual(len(linked_list), 3) + self.assertEqual(linked_list[0], elems[0]) + self.assertEqual(linked_list[1], elems[1]) + self.assertEqual(linked_list[2], elems[2]) + self.assertEqual(linked_list[-1], elems[2]) + self.assertEqual(linked_list[-2], elems[1]) + self.assertEqual(linked_list[-3], elems[0]) + self.assertEqual(list(linked_list), elems) + self.assertEqual(list(reversed(linked_list)), list(reversed(elems))) + + def test_extend(self): + elems = [_TestElement(i) for i in range(3)] + linked_list = _linked_list.DoublyLinkedSet(elems) + self.assertEqual(len(linked_list), 3) + self.assertEqual(linked_list[0], elems[0]) + self.assertEqual(linked_list[1], elems[1]) + self.assertEqual(linked_list[2], elems[2]) + self.assertEqual(linked_list[-1], elems[2]) + self.assertEqual(linked_list[-2], elems[1]) + self.assertEqual(linked_list[-3], elems[0]) + self.assertEqual(list(linked_list), elems) + self.assertEqual(list(reversed(linked_list)), list(reversed(elems))) + + @parameterized.parameterized.expand( + [ + ("single_element", [0], 0, [1], [0, 1]), + ("single_element_negative_index", [0], -1, [1], [0, 1]), + ("multiple_elements", [0], 0, [1, 2], [0, 1, 2]), + ("multiple_elements_negative_index", [0], -1, [1, 2], [0, 1, 2]), + ( + "multiple_original_elements_insert_at_start", + [0, 1, 2], + 0, + [42, 43], + [0, 42, 43, 1, 2], + ), + ( + "multiple_original_elements_insert_at_middle", + [0, 1, 2], + 1, + [42, 43], + [0, 1, 42, 43, 2], + ), + ( + "multiple_original_elements_insert_at_end", + [0, 1, 2], + 2, + [42, 43], + [0, 1, 2, 42, 43], + ), + ] + ) + def test_insert_after( + self, _: str, original: list[int], location: int, insertion: list[int], expected: list + ) -> None: + # Construct the original list + elems = [_TestElement(i) for i in original] + linked_list = _linked_list.DoublyLinkedSet(elems) + + # Create the new elements + new_elements = [_TestElement(i) for i in insertion] + linked_list.insert_after(elems[location], new_elements) + + # Check the list + self.assertEqual(len(linked_list), len(expected)) + self.assertEqual([elem.value for elem in linked_list], expected) + + @parameterized.parameterized.expand( + [ + ("single_element", [0], 0, [1], [1, 0]), + ("single_element_negative_index", [0], -1, [1], [1, 0]), + ("multiple_elements", [0], 0, [1, 3], [1, 3, 0]), + ("multiple_elements_negative_index", [0], -1, [1, 3], [1, 3, 0]), + ( + "multiple_original_elements_insert_at_start", + [0, 1, 2], + 0, + [42, 43], + [42, 43, 0, 1, 2], + ), + ( + "multiple_original_elements_insert_at_middle", + [0, 1, 2], + 1, + [42, 43], + [0, 42, 43, 1, 2], + ), + ( + "multiple_original_elements_insert_at_end", + [0, 1, 2], + 2, + [42, 43], + [0, 1, 42, 43, 2], + ), + ] + ) + def test_insert_before( + self, _: str, original: list[int], location: int, insertion: list[int], expected: list + ) -> None: + # Construct the original list + elems = [_TestElement(i) for i in original] + linked_list = _linked_list.DoublyLinkedSet(elems) + + # Create the new elements + new_elements = [_TestElement(i) for i in insertion] + linked_list.insert_before(elems[location], new_elements) + + # Check the list + self.assertEqual(len(linked_list), len(expected)) + self.assertEqual([elem.value for elem in linked_list], expected) + self.assertEqual([elem.value for elem in reversed(linked_list)], expected[::-1]) + + @parameterized.parameterized.expand( + [ + ("start", 0, [1, 2]), + ("middle", 1, [0, 2]), + ("end", 2, [0, 1]), + ("start_negative", -1, [0, 1]), + ("middle_negative", -2, [0, 2]), + ("end_negative", -3, [1, 2]), + ] + ) + def test_remove(self, _: str, index: int, expected: list[int]) -> None: + elems = [_TestElement(i) for i in range(3)] + linked_list = _linked_list.DoublyLinkedSet(elems) + + linked_list.remove(elems[index]) + + self.assertEqual(len(linked_list), 2) + self.assertEqual([elem.value for elem in linked_list], expected) + self.assertEqual([elem.value for elem in reversed(linked_list)], expected[::-1]) + + def test_remove_raises_when_element_not_found(self) -> None: + elems = [_TestElement(i) for i in range(3)] + linked_list = _linked_list.DoublyLinkedSet(elems) + + with self.assertRaises(ValueError): + linked_list.remove(_TestElement(3)) + + def test_remove_raises_when_element_is_already_removed(self) -> None: + linked_list = _linked_list.DoublyLinkedSet() + elem = _TestElement(0) + linked_list.append(elem) + linked_list.remove(elem) + + with self.assertRaises(ValueError): + linked_list.remove(elem) + + def test_append_self_does_nothing(self) -> None: + linked_list = _linked_list.DoublyLinkedSet() + elem = _TestElement(0) + linked_list.append(elem) + + linked_list.append(elem) + + self.assertEqual(len(linked_list), 1) + self.assertEqual(linked_list[0], elem) + self.assertEqual(list(linked_list), [elem]) + self.assertEqual(list(reversed(linked_list)), [elem]) + + def test_append_supports_appending_element_from_the_same_list(self) -> None: + elems = [_TestElement(i) for i in range(3)] + linked_list = _linked_list.DoublyLinkedSet(elems) + + linked_list.append(elems[1]) + + self.assertEqual(len(linked_list), 3) + self.assertEqual([elem.value for elem in linked_list], [0, 2, 1]) + self.assertEqual([elem.value for elem in reversed(linked_list)], [1, 2, 0]) + + def test_extend_supports_extending_elements_from_the_same_list(self) -> None: + elems = [_TestElement(i) for i in range(3)] + linked_list = _linked_list.DoublyLinkedSet(elems) + linked_list.extend(elems[::-1]) + + self.assertEqual(len(linked_list), 3) + self.assertEqual([elem.value for elem in linked_list], [2, 1, 0]) + self.assertEqual([elem.value for elem in reversed(linked_list)], [0, 1, 2]) + + def test_insert_after_supports_inserting_element_from_the_same_list(self) -> None: + elems = [_TestElement(i) for i in range(3)] + linked_list = _linked_list.DoublyLinkedSet(elems) + linked_list.insert_after(elems[0], [elems[2]]) + + self.assertEqual(len(linked_list), 3) + self.assertEqual([elem.value for elem in linked_list], [0, 2, 1]) + + def test_insert_before_supports_inserting_element_from_the_same_list(self) -> None: + elems = [_TestElement(i) for i in range(3)] + linked_list = _linked_list.DoublyLinkedSet(elems) + linked_list.insert_before(elems[0], [elems[2]]) + + self.assertEqual(len(linked_list), 3) + self.assertEqual([elem.value for elem in linked_list], [2, 0, 1]) + + def test_iterator_supports_mutation_during_iteration_current_element(self) -> None: + elems = [_TestElement(i) for i in range(3)] + linked_list = _linked_list.DoublyLinkedSet(elems) + for elem in linked_list: + if elem.value == 1: + linked_list.remove(elem) + + self.assertEqual(len(linked_list), 2) + self.assertEqual([elem.value for elem in linked_list], [0, 2]) + self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 0]) + + def test_iterator_supports_mutation_during_iteration_previous_element(self) -> None: + elems = [_TestElement(i) for i in range(3)] + linked_list = _linked_list.DoublyLinkedSet(elems) + for elem in linked_list: + if elem.value == 1: + linked_list.remove(elem) + linked_list.remove(elems[0]) + + self.assertEqual(len(linked_list), 1) + self.assertEqual([elem.value for elem in linked_list], [2]) + self.assertEqual([elem.value for elem in reversed(linked_list)], [2]) + + def test_iterator_supports_mutation_during_iteration_next_element(self) -> None: + elems = [_TestElement(i) for i in range(3)] + linked_list = _linked_list.DoublyLinkedSet(elems) + for elem in linked_list: + if elem.value == 1: + linked_list.remove(elems[2]) + linked_list.remove(elem) + + self.assertEqual(len(linked_list), 1) + self.assertEqual([elem.value for elem in linked_list], [0]) + self.assertEqual([elem.value for elem in reversed(linked_list)], [0]) + + def test_iterator_supports_mutation_in_nested_iteration_right_of_iterator(self) -> None: + elems = [_TestElement(i) for i in range(3)] + linked_list = _linked_list.DoublyLinkedSet(elems) + iter1_visited = [] + iter2_visited = [] + for elem in linked_list: + iter1_visited.append(elem.value) + for elem2 in linked_list: + iter2_visited.append(elem2.value) + if elem2.value == 1: + linked_list.remove(elem2) + + self.assertEqual(len(linked_list), 2) + self.assertEqual(iter1_visited, [0, 2]) + self.assertEqual(iter2_visited, [0, 1, 2, 0, 2]) + self.assertEqual([elem.value for elem in linked_list], [0, 2]) + self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 0]) + + def test_iterator_supports_mutation_in_nested_iteration_when_iter_is_self(self) -> None: + elems = [_TestElement(i) for i in range(3)] + linked_list = _linked_list.DoublyLinkedSet(elems) + iter1_visited = [] + iter2_visited = [] + for elem in linked_list: + iter1_visited.append(elem.value) + for elem2 in linked_list: + iter2_visited.append(elem2.value) + if elem2.value == 0: # Remove the element the current iterator points to + linked_list.remove(elem2) + + self.assertEqual(len(linked_list), 2) + self.assertEqual(iter1_visited, [0, 1, 2]) + self.assertEqual(iter2_visited, [0, 1, 2, 1, 2, 1, 2]) + self.assertEqual([elem.value for elem in linked_list], [1, 2]) + self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 1]) + + def test_iterator_supports_mutation_in_nested_iteration_left_of_iterator(self) -> None: + elems = [_TestElement(i) for i in range(3)] + linked_list = _linked_list.DoublyLinkedSet(elems) + iter1_visited = [] + iter2_visited = [] + for elem in linked_list: + iter1_visited.append(elem.value) + for elem2 in linked_list: + iter2_visited.append(elem2.value) + if ( + elem.value == 1 and elem2.value == 0 + ): # Remove the element before the current iterator points to + linked_list.remove(elems[0]) + + self.assertEqual(len(linked_list), 2) + self.assertEqual(iter1_visited, [0, 1, 2]) + self.assertEqual(iter2_visited, [0, 1, 2, 0, 1, 2, 1, 2]) + self.assertEqual([elem.value for elem in linked_list], [1, 2]) + self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 1]) + + def test_insert_after_supports_element_from_different_list_during_iteration(self) -> None: + elems = [_TestElement(i) for i in range(3)] + linked_list = _linked_list.DoublyLinkedSet(elems) + other_linked_list = _linked_list.DoublyLinkedSet() + other_elem = _TestElement(42) + other_linked_list.append(other_elem) + + for elem in linked_list: + if elem.value == 1: + linked_list.insert_after(elem, [other_elem]) + + self.assertEqual(len(linked_list), 4) + self.assertEqual([elem.value for elem in linked_list], [0, 1, 42, 2]) + self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 42, 1, 0]) + # Other list remains unchanged + self.assertEqual(len(other_linked_list), 1) + self.assertEqual([elem.value for elem in other_linked_list], [42]) + + def test_insert_after_supports_taking_elements_from_another_doubly_linked_list( + self, + ) -> None: + elems = [_TestElement(i) for i in range(3)] + linked_list = _linked_list.DoublyLinkedSet(elems) + other_linked_list = _linked_list.DoublyLinkedSet() + other_elem = _TestElement(42) + other_linked_list.append(other_elem) + + linked_list.insert_after(elems[1], other_linked_list) + + self.assertEqual(len(linked_list), 4) + self.assertEqual([elem.value for elem in linked_list], [0, 1, 42, 2]) + self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 42, 1, 0]) + # Other list remains unchanged + self.assertEqual(len(other_linked_list), 1) + self.assertEqual([elem.value for elem in other_linked_list], [42]) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/ir/_metadata.py b/onnxscript/ir/_metadata.py index d7eaac4ef..a29e44712 100644 --- a/onnxscript/ir/_metadata.py +++ b/onnxscript/ir/_metadata.py @@ -1,3 +1,7 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- """Class for storing metadata about the IR objects.""" from __future__ import annotations diff --git a/onnxscript/ir/_protocols.py b/onnxscript/ir/_protocols.py index e384ec1cc..a65912183 100644 --- a/onnxscript/ir/_protocols.py +++ b/onnxscript/ir/_protocols.py @@ -1,21 +1,13 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- """Protocols for the ONNX IR. This file defines the interfaces for tools to interact with the IR. The interfaces are designed such that tools leveraging the IR can be decoupled from the IR implementation. This allows for the implementation to evolve independently of the tools. - -The file contains two sets of interfaces: -1. Topologically immutable interfaces: - These interfaces provide a complete view of the ONNX model and allows mutation - against any metadata fields like shape, type, and node attributes. However, the - interfaces are topologically immutable, meaning that the structure of the graph - cannot be changed. This is useful for tools that need to analyze the model - without modifying how nodes are connected. -2. Mutable interfaces: - These interfaces provide a mutable view of the ONNX model. They allow for - modification of the graph structure. This is useful for tools that need to - transform the model. """ from __future__ import annotations @@ -24,6 +16,7 @@ from typing import ( AbstractSet, Any, + Iterable, Iterator, Mapping, OrderedDict, @@ -159,6 +152,7 @@ class ValueProtocol(Protocol): shape: ShapeProtocol | None type: TypeProtocol | None metadata_props: Mapping[str, str] + meta: Mapping[str, Any] def users(self) -> AbstractSet[tuple[NodeProtocol, int]]: """The set of (node, input_index) with node being those that use this value as an input.""" @@ -219,6 +213,11 @@ class NodeProtocol(Protocol): version: int | None doc_string: str | None metadata_props: Mapping[str, str] + meta: Mapping[str, Any] + + def replace_input_with(self, index: int, value: ValueProtocol | None) -> None: + """Set the input at the given index to the given value, replacing the original value.""" + ... @typing.runtime_checkable @@ -254,9 +253,36 @@ class GraphProtocol(Protocol): doc_string: str opset_imports: Mapping[str, int] metadata_props: Mapping[str, str] + meta: Mapping[str, Any] + + def __getitem__(self, index: int) -> NodeProtocol: ... + def __len__(self) -> int: ... + def __iter__(self) -> Iterator[NodeProtocol]: ... + def __reversed__(self) -> Iterator[NodeProtocol]: ... + + # Mutation methods + def append(self, node: NodeProtocol, /) -> None: + """Append a node to the graph.""" + ... + + def extend(self, nodes: Iterable[NodeProtocol], /) -> None: + """Extend the graph with the given nodes.""" + ... + + def remove(self, node: NodeProtocol, /) -> None: + """Remove a node from the graph.""" + ... + + def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None: + """Insert new nodes after the given node.""" + ... - def topologically_sorted_nodes(self) -> Sequence[NodeProtocol]: - """Return the nodes in topological order.""" + def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None: + """Insert new nodes before the given node.""" + ... + + def sort(self) -> None: + """Topologically sort the nodes in the graph.""" ... @@ -290,6 +316,7 @@ class ModelProtocol(Protocol): # TODO(justinchuby): Add training_info opset_imports: Mapping[str, int] metadata_props: Mapping[str, str] + meta: Mapping[str, Any] @typing.runtime_checkable @@ -451,11 +478,38 @@ class FunctionProtocol(Protocol): opset_imports: Mapping[str, int] nodes: Sequence[NodeProtocol] metadata_props: Mapping[str, str] + meta: Mapping[str, Any] + def __getitem__(self, index: int) -> NodeProtocol: ... + def __len__(self) -> int: ... + def __iter__(self) -> Iterator[NodeProtocol]: ... + def __reversed__(self) -> Iterator[NodeProtocol]: ... def identifier(self) -> OperatorIdentifier: """Return the unique identifier of the function.""" ... - def topologically_sorted_nodes(self) -> Sequence[NodeProtocol]: - """Return the nodes in topological order.""" + # Mutation methods + # End Block + def append(self, node: NodeProtocol, /) -> None: + """Append a node to the function.""" + ... + + def extend(self, nodes: Iterable[NodeProtocol], /) -> None: + """Extend the function with the given nodes.""" + ... + + def remove(self, node: NodeProtocol, /) -> None: + """Remove a node from the function.""" + ... + + def insert_after(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None: + """Insert new nodes after the given node.""" + ... + + def insert_before(self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol], /) -> None: + """Insert new nodes before the given node.""" + ... + + def sort(self) -> None: + """Topologically sort the nodes in the function.""" ... diff --git a/onnxscript/ir/convenience.py b/onnxscript/ir/convenience.py index 708e0c65d..12bb21090 100644 --- a/onnxscript/ir/convenience.py +++ b/onnxscript/ir/convenience.py @@ -1,3 +1,7 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- """Convenience methods for constructing (and manipulating?) the IR.""" from __future__ import annotations diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index dc29d826c..d16aaae23 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -1,3 +1,7 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- """Serialize and deserialize the intermediate representation to/from ONNX protos.""" # NOTES for developers: