Skip to content

Commit

Permalink
[IR] Support specifying output values in Node init (#1507)
Browse files Browse the repository at this point in the history
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* #1509
* #1508
* __->__ #1507

This way users can pre-initialize output values, sometimes elsewhere, or
prefill type/shape information before initializing the node.
  • Loading branch information
justinchuby authored May 7, 2024
1 parent c8cd684 commit 1e4b585
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 6 deletions.
64 changes: 59 additions & 5 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,8 @@ def __init__(
attributes: Iterable[Attr | RefAttr] = (),
*,
overload: str = "",
num_outputs: int = 1,
num_outputs: int | None = None,
outputs: Sequence[Value] | None = None,
version: int | None = None,
graph: Graph | None = None,
name: str | None = None,
Expand All @@ -916,13 +917,20 @@ def __init__(
inputs: The input values. When an input is None, it is an empty input.
attributes: The attributes. RefAttr can be used only when the node is defined in a Function.
overload: The overload name when the node is invoking a function.
num_outputs: The number of outputs of the node.
num_outputs: The number of outputs of the node. If not specified, the number is 1.
outputs: The output values. If None, the outputs are created during initialization.
version: The version of the operator. If None, the version is unspecified and will follow that of the graph.
graph: The graph that the node belongs to. If None, the node is not added to any graph.
A `Node` must belong to zero or one graph.
name: The name of the node. If None, the node is anonymous.
doc_string: The documentation string.
metadata_props: The metadata properties.
Raises:
TypeError: If the attributes are not Attr or RefAttr.
ValueError: If `num_outputs`, when not None, is not the same as the length of the outputs.
ValueError: If an output value is None, when outputs is specified.
ValueError: If an output value has a producer set already, when outputs is specified.
"""
self._name = name
self._domain: str = domain
Expand All @@ -932,9 +940,7 @@ def __init__(
# If necessary, we can cache the inputs and outputs as tuples.
self._inputs: tuple[Value | None, ...] = tuple(inputs)
# Values belong to their defining nodes. The values list is immutable
self._outputs: tuple[Value, ...] = tuple(
Value(self, index=i) for i in range(num_outputs)
)
self._outputs: tuple[Value, ...] = self._create_outputs(num_outputs, outputs)
attributes = tuple(attributes)
if attributes and not isinstance(attributes[0], (Attr, RefAttr)):
raise TypeError(
Expand Down Expand Up @@ -962,6 +968,54 @@ def __init__(
if self._graph is not None:
self._graph.append(self)

def _create_outputs(
self, num_outputs: int | None, outputs: Sequence[Value] | None
) -> tuple[Value, ...]:
"""Check the parameters and create outputs for the node.
Args:
num_outputs: The number of outputs of the node.
outputs: The output values of the node.
Returns:
The output values of the node.
Raises:
ValueError: If `num_outputs`, when not None, is not the same as the length of the outputs.
ValueError: If an output value is None.
ValueError: If an output value has a producer set already.
"""
# Check num_outputs and outputs are consistent
if num_outputs is not None and outputs is not None and num_outputs != len(outputs):
raise ValueError(
"num_outputs must be the same as len(outputs) when num_outputs is specified."
"num_outputs: {num_outputs}, outputs: {outputs}"
)
# 1. If outputs is specified (can be empty []), use the outputs
if outputs is not None:
# Check all output values are valid first
for output in outputs:
if output is None:
raise ValueError(f"Output value cannot be None. All outputs: {outputs}")
if output.producer() is not None:
raise ValueError(
f"Supplied output value cannot have a producer when used for initializing a Node. "
f"Output: {output}. All outputs: {outputs}"
)
result = []
for i, output in enumerate(outputs):
output._producer = self # pylint: disable=protected-access
output._index = i # pylint: disable=protected-access
result.append(output)
return tuple(result)

# 2. If num_outputs is specified, create num_outputs outputs
if num_outputs is None:
# Default to 1 output
num_outputs = 1
assert num_outputs is not None
return tuple(Value(self, index=i) for i in range(num_outputs))

def __str__(self) -> str:
node_type_text = f"{self._domain}::{self._op_type}" + f":{self._overload}" * (
self._overload != ""
Expand Down
47 changes: 46 additions & 1 deletion onnxscript/ir/_core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,13 +416,58 @@ def setUp(self) -> None:
self.v1 = _core.Value(None, index=None)
self.node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=3)

def test_initialize_with_values(self):
def test_init_with_values(self):
self.assertEqual(self.node.domain, "test")
self.assertEqual(self.node.op_type, "TestOp")
self.assertEqual(self.node.inputs, (self.v0, self.v1))
self.assertEqual(len(self.node.outputs), 3)
self.assertEqual(self.node.attributes, {})

def test_init_with_preinitialized_outputs(self):
out_1 = _core.Value(
None,
index=None,
name="out_1",
shape=_core.Shape([1]),
type=_core.TensorType(_enums.DataType.BFLOAT16),
)
out_2 = _core.Value(
None,
index=None,
name="out_2",
shape=_core.Shape([2]),
type=_core.TensorType(_enums.DataType.INT4),
)
node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), outputs=[out_1, out_2])
self.assertEqual(node.outputs[0].name, "out_1")
self.assertEqual(node.outputs[0].shape, _core.Shape([1]))
self.assertEqual(node.outputs[0].dtype, _enums.DataType.BFLOAT16)
self.assertEqual(node.outputs[1].name, "out_2")
self.assertEqual(node.outputs[1].shape, _core.Shape([2]))
self.assertEqual(node.outputs[1].dtype, _enums.DataType.INT4)
self.assertIs(node.outputs[0], out_1)
self.assertIs(node.outputs[1], out_2)
self.assertIs(node.outputs[0].producer(), node)
self.assertIs(node.outputs[1].producer(), node)
self.assertIs(node.outputs[0].index(), 0)
self.assertIs(node.outputs[1].index(), 1)

def test_init_raises_when_num_outputs_does_not_match_outputs(self):
with self.assertRaisesRegex(ValueError, "outputs"):
_core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=2, outputs=[])

def test_init_with_zero_num_outputs(self):
node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=0)
self.assertEqual(node.outputs, ())

def test_init_with_empty_outputs(self):
node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), outputs=[])
self.assertEqual(node.outputs, ())

def test_init_produces_one_output_with_unspecified_output_argument(self):
node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1))
self.assertEqual(len(node.outputs), 1)

def test_metadata(self):
self.node.meta["test"] = 1
self.assertEqual(self.node.meta["test"], 1)
Expand Down

0 comments on commit 1e4b585

Please sign in to comment.