Skip to content

Commit

Permalink
[IR] Make add_user in Value private (#1370)
Browse files Browse the repository at this point in the history
Make `add_user` in Value private and note that they are only used by the
`Node` class.
  • Loading branch information
justinchuby authored Apr 11, 2024
1 parent e3bbfaa commit f3c1ab4
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def __init__(
# Add the node as a user of the inputs
for i, input_value in enumerate(self._inputs):
if input_value is not None:
input_value.add_user(self, i)
input_value._add_user(self, i) # pylint: disable=protected-access

# Add the node to the graph if graph is specified
if self._graph is not None:
Expand Down Expand Up @@ -705,9 +705,9 @@ def replace_input_with(self, index: int, value: Value | None) -> None:
value if i == index else old_input for i, old_input in enumerate(self.inputs)
)
if old_input is not None:
old_input.remove_user(self, index)
old_input._remove_user(self, index) # pylint: disable=protected-access
if value is not None:
value.add_user(self, index)
value._add_user(self, index) # pylint: disable=protected-access

def prepend(self, /, nodes: Node | Iterable[Node]) -> None:
"""Insert a node before this node in the list of nodes in the graph.
Expand Down Expand Up @@ -951,13 +951,25 @@ def def_index(self) -> int | None:
return self._def_index

def users(self) -> frozenset[tuple[Node, int]]:
"""Return a set of users of the value.
The set contains tuples of ``(Node, index)`` where the index is the index of the input
of the node. For example, if ``node.inputs[1] == value``, then the user is ``(node, 1)``.
"""
return frozenset(self._users)

def add_user(self, user: Node, index: int) -> None:
def _add_user(self, user: Node, index: int) -> None:
"""Add a user node.
This is an internal method. It should only be called by the Node class.
"""
self._users.add((user, index))

def remove_user(self, user: Node, index: int) -> None:
"""Reduce a user node."""
def _remove_user(self, user: Node, index: int) -> None:
"""Remove a node from the users of this value.
This is an internal method. It should only be called by the Node class.
"""
self._users.remove((user, index))

@property
Expand Down

0 comments on commit f3c1ab4

Please sign in to comment.