Skip to content

Commit

Permalink
[IR] Mutable IR implementation (#1344)
Browse files Browse the repository at this point in the history
Implement mutation methods on Graph and Function and a DoublyLinkedList
to support safe mutation during iteration. Nodes in different graphs can
be moved to other graphs safely during iteration.

### Methods implemented

```
- __getitem__
- __len__
- __iter__
- __reversed__
- append
- extend
- remove
- insert_after
- insert_before
- sort (declared, not implemented)
```

The mutation methods are inspired by the pytorch FX graph methods.

### Safe Iterators

The behavior of the iterators is:

- If new elements are inserted after the current node, the iterator will
iterate over them as well.
- If new elements are inserted before the current node, they will not be
iterated over in this iteration.
- If the current node is lifted and inserted in a different location,
iteration will start from the "next" node at the _original_ location.

A node cannot be added to a graph if it belongs to another graph. It
needs to be removed first. The user is responsible for removing nodes
from other graphs before adding the same node to a new graph.

### Linked list implementation
The doubly linked list implementation is inspired by the pytorch FX
graph: both are doubly linked lists with a dummy root node. Whereas in
FX the `graph` contains the root node and the `Node`s are the links, we
create the `DoublyLinkedOrderedSet` class to decouple the list
implementation from the graph. Additionally, we created the `LinkedBox`
class as a data struct to hold the prev/next pointers for the Nodes to
allow nodes to move across graphs. This is in contrast to FX's
restriction that all nodes must belong to the same graph throughout
their lifecycle. By allowing the nodes to move to different graphs, we
are able to flatten/unflatten graphs without copying.

DoublyLinkedOrderedSet is unit tested.

### Next steps

1. Naming authority
2. Graph input/output/initializers mutation
  • Loading branch information
justinchuby authored Apr 9, 2024
1 parent edb68ed commit b66c6d1
Show file tree
Hide file tree
Showing 14 changed files with 993 additions and 84 deletions.
4 changes: 4 additions & 0 deletions onnxscript/ir/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""In-memory intermediate representation for ONNX graphs."""

__all__ = [
Expand Down
294 changes: 226 additions & 68 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""data structures for the intermediate representation."""

# NOTES for developers:
Expand Down Expand Up @@ -33,6 +37,7 @@
from onnxscript.ir import (
_display,
_enums,
_linked_list,
_metadata,
_protocols,
)
Expand Down Expand Up @@ -548,6 +553,10 @@ def __init__(
if input_value is not None:
input_value.add_user(self, i)

# Add the node to the graph if graph is specified
if self._graph is not None:
self._graph.append(self)

def __str__(self) -> str:
node_type_text = f"{self._domain}::{self._op_type}" + f":{self._overload}" * (
self._overload != ""
Expand Down Expand Up @@ -633,18 +642,56 @@ def inputs(self, _: Any) -> None:
"Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead."
)

def replace_input_with(self, index: int, new_input: Value | None) -> None:
def replace_input_with(self, index: int, value: Value | None) -> None:
"""Replace an input with a new value."""
if index < 0 or index >= len(self.inputs):
raise ValueError(f"Index out of range: {index}")
old_input = self.inputs[index]
self._inputs = tuple(
new_input if i == index else old_input for i, old_input in enumerate(self.inputs)
value if i == index else old_input for i, old_input in enumerate(self.inputs)
)
if old_input is not None:
old_input.remove_user(self, index)
if new_input is not None:
new_input.add_user(self, index)
if value is not None:
value.add_user(self, index)

def prepend(self, /, nodes: Node | Iterable[Node]) -> None:
    """Place one or more nodes immediately before this node in its graph.

    Equivalent to calling ``graph.insert_before(self, nodes)``.

    Example::

        Before: previous_node -> self
                previous_node' -> node -> next_node'
        After:  previous_node -> node -> self
                previous_node' -> next_node'

    Args:
        nodes: A node or a sequence of nodes to put before this node.

    Raises:
        ValueError: If this node does not belong to any graph.
    """
    owning_graph = self._graph
    if owning_graph is None:
        raise ValueError("The node to prepend to does not belong to any graph.")
    owning_graph.insert_before(self, nodes)

def append(self, /, nodes: Node | Iterable[Node]) -> None:
    """Place one or more nodes immediately after this node in its graph.

    Equivalent to calling ``graph.insert_after(self, nodes)``.

    Example::

        Before: previous_node -> self
                previous_node' -> node -> next_node'
        After:  previous_node -> self -> node
                previous_node' -> next_node'

    Args:
        nodes: A node or a sequence of nodes to put after this node.

    Raises:
        ValueError: If this node does not belong to any graph.
    """
    owning_graph = self._graph
    if owning_graph is None:
        raise ValueError("The node to append to does not belong to any graph.")
    owning_graph.insert_after(self, nodes)

@property
def outputs(self) -> Sequence[Value]:
Expand Down Expand Up @@ -985,6 +1032,8 @@ def __init__(
name: str | None = None,
):
self.name = name

# Private fields that are not to be accessed by any other classes
self._inputs = tuple(inputs)
self._outputs = tuple(outputs)
for initializer in initializers:
Expand All @@ -995,11 +1044,9 @@ def __init__(
self._opset_imports = opset_imports or {}
self._metadata: _metadata.MetadataStore | None = None
self._metadata_props: dict[str, str] | None = None

# Assign this graph as the owning_graph of all nodes
self._nodes = list(nodes)
for node in self._nodes:
node.graph = self
self._nodes: _linked_list.DoublyLinkedSet[Node] = _linked_list.DoublyLinkedSet()
# Call self.extend not self._nodes.extend so the graph reference is added to the nodes
self.extend(nodes)

@property
def inputs(self) -> tuple[Value, ...]:
Expand Down Expand Up @@ -1027,11 +1074,104 @@ def opset_imports(self) -> dict[str, int]:

@property
def nodes(self) -> Sequence[Node]:
return self._nodes
return tuple(self._nodes)

def topologically_sorted_nodes(self) -> Sequence[Node]:
def __getitem__(self, index: int) -> Node:
    """Return the node stored at position ``index`` of the node list."""
    node_list = self._nodes
    return node_list[index]

def __len__(self) -> int:
    """Return how many nodes the node list currently holds."""
    node_list = self._nodes
    return len(node_list)

def __iter__(self) -> Iterator[Node]:
    """Return an iterator over the node list, front to back."""
    node_list = self._nodes
    return iter(node_list)

def __reversed__(self) -> Iterator[Node]:
    """Return an iterator over the node list, back to front."""
    node_list = self._nodes
    return reversed(node_list)

def _set_node_graph_to_self(self, node: Node) -> Node:
    """Point the node's graph reference at this graph and hand the node back.

    Args:
        node: The node to claim for this graph.

    Returns:
        The same ``node`` object, so the call can be used inside expressions.

    Raises:
        ValueError: If the node already belongs to a different graph.
    """
    current_owner = node.graph
    owned_elsewhere = not (current_owner is None or current_owner is self)
    if owned_elsewhere:
        raise ValueError(
            f"The node {node} belongs to another graph. Please remove it first with Graph.remove()."
        )
    # The private field is written directly on the node.
    node._graph = self  # pylint: disable=protected-access
    return node

# Mutation methods
def append(self, node: Node, /) -> None:
    """Add a single node at the end of the graph in O(1) time.

    Args:
        node: The node to append.

    Raises:
        ValueError: If the node belongs to another graph.
    """
    # Claim ownership first so a rejected node is never added to the list.
    self._set_node_graph_to_self(node)
    node_list = self._nodes
    node_list.append(node)

def extend(self, nodes: Iterable[Node], /) -> None:
    """Add every given node at the end of the graph in O(#new_nodes) time.

    Args:
        nodes: The nodes to extend the graph with.

    Raises:
        ValueError: If any node belongs to another graph.
    """
    adopted = []
    for candidate in nodes:
        adopted.append(self._set_node_graph_to_self(candidate))
    self._nodes.extend(adopted)

def remove(self, node: Node, /) -> None:
    """Detach a node from the graph in O(1) time.

    Args:
        node: The node to remove.

    Raises:
        ValueError: If the node does not belong to this graph.
    """
    owner = node.graph
    if owner is not self:
        raise ValueError(f"The node {node} does not belong to this graph.")
    # Clear the back-reference, then unlink the node from the node list.
    node._graph = None  # pylint: disable=protected-access
    node_list = self._nodes
    node_list.remove(node)

def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
    """Insert one or more nodes right after ``node`` in O(#new_nodes) time.

    Args:
        node: The node to insert after.
        new_nodes: The new nodes to insert; a single node is also accepted.

    Raises:
        ValueError: If any node belongs to another graph.
    """
    # Normalize a lone node into a one-element tuple.
    candidates = (new_nodes,) if isinstance(new_nodes, Node) else new_nodes
    adopted = [self._set_node_graph_to_self(candidate) for candidate in candidates]
    self._nodes.insert_after(node, adopted)

def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
    """Insert one or more nodes right before ``node`` in O(#new_nodes) time.

    Args:
        node: The node to insert before.
        new_nodes: The new nodes to insert; a single node is also accepted.

    Raises:
        ValueError: If any node belongs to another graph.
    """
    # Normalize a lone node into a one-element tuple.
    candidates = (new_nodes,) if isinstance(new_nodes, Node) else new_nodes
    adopted = [self._set_node_graph_to_self(candidate) for candidate in candidates]
    self._nodes.insert_before(node, adopted)

def sort(self) -> None:
    """Topologically sort the nodes in the graph.

    Raises:
        NotImplementedError: Always; the method is declared but has no
            implementation yet.
    """
    message = "Not implemented yet"
    raise NotImplementedError(message)

# End of mutation methods

@property
def meta(self) -> _metadata.MetadataStore:
"""The metadata store for intermediate analysis.
Expand All @@ -1049,15 +1189,6 @@ def metadata_props(self) -> dict[str, str]:
self._metadata_props = {}
return self._metadata_props

def __getitem__(self, index: int) -> Node:
    """Return the node stored at position ``index`` of the node list."""
    node_list = self._nodes
    return node_list[index]

def __len__(self) -> int:
    """Return how many nodes the node list currently holds."""
    node_list = self._nodes
    return len(node_list)

def __iter__(self) -> Iterator[Node]:
    """Return an iterator over the node list, front to back."""
    node_list = self._nodes
    return iter(node_list)

def __str__(self) -> str:
# TODO(justinchuby): Show docstrings and metadata
inputs_text = "\n" + ",\n".join(str(x) for x in self.inputs)
Expand Down Expand Up @@ -1155,7 +1286,7 @@ def __init__(
doc_string: str | None = None,
functions: Sequence[Function] = (),
) -> None:
self.graph: Graph = graph
self.graph: Graph = graph # type: ignore[assignment]
self.ir_version = ir_version
self.producer_name = producer_name
self.producer_version = producer_version
Expand Down Expand Up @@ -1246,53 +1377,6 @@ def __init__(
self._metadata: _metadata.MetadataStore | None = None
self._metadata_props: dict[str, str] | None = None

def __str__(self) -> str:
    """Render the function as multi-line, human-readable text.

    The result is a signature (opset imports, qualified name, inputs,
    attributes and outputs) followed by a brace-delimited body that lists
    each node with its index and ends with a ``return`` line.

    Returns:
        The formatted text.
    """
    full_name = f"{self.domain}::{self.name}" + f":{self.overload}" * (self.overload != "")
    inputs_text = ",\n".join(str(x) for x in self.inputs)
    outputs_text = ",\n".join(str(x) for x in self.outputs)
    # Multiplying a string by a bool keeps it (True) or drops it (False):
    # the "= value" suffix is shown only for attributes that carry a value.
    attributes_text = ",\n".join(
        attr.name + f": {attr.type}" + f"= {attr.value}" * (attr.value is not None)
        for attr in self.attributes.values()
    )
    if attributes_text:
        attributes_text = (
            "\nattributes={\n" + textwrap.indent(attributes_text, " " * 4) + "\n}"
        )
    # NOTE(review): the template's leading whitespace is reproduced exactly as
    # seen in the reviewed text; confirm it against the checked-in file.
    signature = f"""\
<
opset_imports={self.opset_imports!r},
>
def {full_name}(
inputs=(
{textwrap.indent(inputs_text, ' '*8)}
),{textwrap.indent(attributes_text, ' '*4)}
outputs=(
{textwrap.indent(outputs_text, ' '*8)}
),
)"""
    node_count = len(self.nodes)
    # Width of the largest index so the "index | node" column lines up.
    number_width = len(str(node_count))
    node_lines = []
    for i, node in enumerate(self.nodes):
        node_name = node.name if node.name else f":anonymous_node:{id(node)}"
        node_text = f"# {node_name}\n{node}"
        indented_node_text = textwrap.indent(node_text, " " * (number_width + 4))
        # Remove the leading spaces
        indented_node_text = indented_node_text.strip()
        node_lines.append(f"{i:>{number_width}} | {indented_node_text}")
    returns = ", ".join(str(x) for x in self.outputs)
    body = (
        "{\n"
        + textwrap.indent("\n".join(node_lines), " " * 4)
        + textwrap.indent(f"\nreturn {returns}", " " * 4)
        + "\n}"
    )

    return f"{signature} {body}"

def __repr__(self) -> str:
    """Return a constructor-like representation with balanced parentheses.

    The text lists the domain, name and overload positionally, followed by
    the ``inputs``, ``attributes`` and ``outputs`` keyword sections.
    """
    # The closing parenthesis comes only after ``outputs``; an earlier one
    # after ``attributes`` would leave the text unbalanced.
    return f"{self.__class__.__name__}({self.domain!r}, {self.name!r}, {self.overload!r}, inputs={self.inputs!r}, attributes={self.attributes!r}, outputs={self.outputs!r})"

def identifier(self) -> _protocols.OperatorIdentifier:
    """Return the ``(domain, name, overload)`` triple read from this object."""
    domain, name, overload = self.domain, self.name, self.overload
    return (domain, name, overload)

Expand Down Expand Up @@ -1364,6 +1448,80 @@ def metadata_props(self) -> dict[str, str]:
self._metadata_props = {}
return self._metadata_props

# Mutation methods
def append(self, node: Node, /) -> None:
    """Add a single node at the end of the function in O(1) time.

    The call is forwarded unchanged to the wrapped graph's ``append``.

    Args:
        node: The node to append.
    """
    wrapped_graph = self._graph
    wrapped_graph.append(node)

def extend(self, nodes: Iterable[Node], /) -> None:
    """Add every given node at the end of the function in O(#new_nodes) time.

    The call is forwarded unchanged to the wrapped graph's ``extend``.

    Args:
        nodes: The nodes to extend the function with.
    """
    wrapped_graph = self._graph
    wrapped_graph.extend(nodes)

def remove(self, node: Node, /) -> None:
    """Detach a node from the function in O(1) time.

    The call is forwarded unchanged to the wrapped graph's ``remove``.

    Args:
        node: The node to remove.
    """
    wrapped_graph = self._graph
    wrapped_graph.remove(node)

def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
    """Insert new nodes after the given node in O(#new_nodes) time.

    The call is forwarded unchanged to the wrapped graph's ``insert_after``.
    The annotation of ``new_nodes`` mirrors ``Graph.insert_after``, which
    accepts either an iterable of nodes or a single node.

    Args:
        node: The node to insert after.
        new_nodes: The new nodes to insert, or a single node.
    """
    self._graph.insert_after(node, new_nodes)

def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
    """Insert new nodes before the given node in O(#new_nodes) time.

    The call is forwarded unchanged to the wrapped graph's ``insert_before``.
    The annotation of ``new_nodes`` mirrors ``Graph.insert_before``, which
    accepts either an iterable of nodes or a single node.

    Args:
        node: The node to insert before.
        new_nodes: The new nodes to insert, or a single node.
    """
    self._graph.insert_before(node, new_nodes)

def sort(self) -> None:
    """Topologically sort the nodes in the function.

    The call is forwarded unchanged to the wrapped graph's ``sort``.
    """
    wrapped_graph = self._graph
    wrapped_graph.sort()

# End of mutation methods

def __str__(self) -> str:
    """Render the function as multi-line, human-readable text.

    The result is a signature (opset imports, qualified name, inputs,
    attributes and outputs) followed by a brace-delimited body that lists
    each node with its index and ends with a ``return`` line.

    Returns:
        The formatted text.
    """
    full_name = f"{self.domain}::{self.name}" + f":{self.overload}" * (self.overload != "")
    inputs_text = ",\n".join(str(x) for x in self.inputs)
    outputs_text = ",\n".join(str(x) for x in self.outputs)
    # Multiplying a string by a bool keeps it (True) or drops it (False):
    # the "= value" suffix is shown only for attributes that carry a value.
    attributes_text = ",\n".join(
        attr.name + f": {attr.type}" + f"= {attr.value}" * (attr.value is not None)
        for attr in self.attributes.values()
    )
    if attributes_text:
        attributes_text = (
            "\nattributes={\n" + textwrap.indent(attributes_text, " " * 4) + "\n}"
        )
    # NOTE(review): the template's leading whitespace is reproduced exactly as
    # seen in the reviewed text; confirm it against the checked-in file.
    signature = f"""\
<
opset_imports={self.opset_imports!r},
>
def {full_name}(
inputs=(
{textwrap.indent(inputs_text, ' '*8)}
),{textwrap.indent(attributes_text, ' '*4)}
outputs=(
{textwrap.indent(outputs_text, ' '*8)}
),
)"""
    node_count = len(self.nodes)
    # Width of the largest index so the "index | node" column lines up.
    number_width = len(str(node_count))
    node_lines = []
    for i, node in enumerate(self.nodes):
        node_name = node.name if node.name else f":anonymous_node:{id(node)}"
        node_text = f"# {node_name}\n{node}"
        indented_node_text = textwrap.indent(node_text, " " * (number_width + 4))
        # Remove the leading spaces
        indented_node_text = indented_node_text.strip()
        node_lines.append(f"{i:>{number_width}} | {indented_node_text}")
    returns = ", ".join(str(x) for x in self.outputs)
    body = (
        "{\n"
        + textwrap.indent("\n".join(node_lines), " " * 4)
        + textwrap.indent(f"\nreturn {returns}", " " * 4)
        + "\n}"
    )

    return f"{signature} {body}"

def __repr__(self) -> str:
    """Return a constructor-like representation with balanced parentheses.

    The text lists the domain, name and overload positionally, followed by
    the ``inputs``, ``attributes`` and ``outputs`` keyword sections.
    """
    # The closing parenthesis comes only after ``outputs``; an earlier one
    # after ``attributes`` would leave the text unbalanced.
    return f"{self.__class__.__name__}({self.domain!r}, {self.name!r}, {self.overload!r}, inputs={self.inputs!r}, attributes={self.attributes!r}, outputs={self.outputs!r})"


class RefAttr(_protocols.ReferenceAttributeProtocol, _display.PrettyPrintable):
"""Reference attribute."""
Expand Down
4 changes: 4 additions & 0 deletions onnxscript/ir/_core_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from __future__ import annotations

import pathlib
Expand Down
Loading

0 comments on commit b66c6d1

Please sign in to comment.