Skip to content

Commit

Permalink
[IR] Implement register_initializer (#1941)
Browse files Browse the repository at this point in the history
Implement `register_initializer(value)` on `Graph` for robustly adding
an initializer to the graph. Users can also directly modify the
`graph.initializers` dictionary, but this method does more comprehensive
checks before adding, and calling this method is simpler than modifying
the dictionary.

The method could be in the `convenience` too, but I put it here due to
its relevance and for better discoverability.
  • Loading branch information
justinchuby authored Nov 13, 2024
1 parent d36184f commit fa7d13a
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
36 changes: 36 additions & 0 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1824,6 +1824,42 @@ def outputs(self) -> list[Value]:
def initializers(self) -> dict[str, Value]:
return self._initializers

def register_initializer(self, value: Value) -> None:
"""Register an initializer to the graph.
This is a convenience method to register an initializer to the graph with
checks.
Args:
value: The :class:`Value` to register as an initializer of the graph.
It must have its ``.const_value`` set.
Raises:
ValueError: If a value of the same name that is not this value
is already registered.
ValueError: If the value does not have a name.
ValueError: If the initializer is produced by a node.
ValueError: If the value does not have its ``.const_value`` set.
"""
if value.name in self._initializers:
if self._initializers[value.name] is not value:
raise ValueError(
f"Initializer '{value.name}' is already registered, but"
" it is not the same object: existing={self._initializers[value.name]!r},"
f" new={value!r}"
)
if not value.name:
raise ValueError(f"Initializer must have a name: {value!r}")
if value.producer() is not None:
raise ValueError(
f"Value '{value!r}' is produced by a node and cannot be an initializer."
)
if value.const_value is None:
raise ValueError(
f"Value '{value!r}' must have its const_value set to be an initializer."
)
self._initializers[value.name] = value

@property
def doc_string(self) -> str | None:
return self._doc_string
Expand Down
24 changes: 24 additions & 0 deletions onnxscript/ir/_core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,30 @@ def test_remove_safe_removes_uses_of_removed_nodes(self):
self.assertEqual(tuple(graph), (sub_node, identity_node))
self.assertEqual(add_node.inputs, (None, None))

def test_register_initializer(self):
self.v1.const_value = ir.tensor([1, 2, 3])
self.graph.register_initializer(self.v1)
self.assertEqual(self.graph.initializers, {self.v1.name: self.v1})

def test_register_initializer_raises_when_value_is_not_constant(self):
with self.assertRaises(ValueError):
self.graph.register_initializer(self.v0)

def test_register_initializer_raises_when_a_different_value_is_already_registered(self):
self.v1.const_value = ir.tensor([1, 2, 3])
self.graph.register_initializer(self.v1)
# This is fine
self.graph.register_initializer(self.v1)
self.v0.name = "v1"
with self.assertRaisesRegex(ValueError, "already registered"):
# Registering a different value with the same name should raise
self.graph.register_initializer(self.v0)

def test_register_initializer_raises_when_value_does_not_have_a_name(self):
self.v1.name = None
with self.assertRaises(ValueError):
self.graph.register_initializer(self.v1)

# TODO(justinchuby): Test graph mutation methods

# Test topological sort.
Expand Down

0 comments on commit fa7d13a

Please sign in to comment.