From fa0f5a4d4c13f54e5247ea16f43b928d4850c34b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Mon, 12 Aug 2024 16:56:48 +0100 Subject: [PATCH] feat(py): Allow pre-declaring a `Function`'s output types (#1417) This is required to use the function in recursive calls (where the `Function` hasn't been completely defined yet). --------- Co-authored-by: Seyon Sivarajah --- hugr-py/src/hugr/dfg.py | 43 +++++++++++++++++++++++++++++++- hugr-py/tests/test_hugr_build.py | 21 ++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/hugr-py/src/hugr/dfg.py b/hugr-py/src/hugr/dfg.py index f1ea17cd0..8fa9e0082 100644 --- a/hugr-py/src/hugr/dfg.py +++ b/hugr-py/src/hugr/dfg.py @@ -41,6 +41,7 @@ def define_function( self, name: str, input_types: TypeRow, + output_types: TypeRow | None = None, type_params: list[TypeParam] | None = None, parent: ToNode | None = None, ) -> Function: @@ -49,6 +50,8 @@ def define_function( Args: name: The name of the function. input_types: The input types for the function. + output_types: The output types for the function. + If not provided, it will be inferred after the function is built. type_params: The type parameters for the function, if polymorphic. parent: The parent node of the constant. Defaults to the root node. @@ -57,7 +60,10 @@ def define_function( """ parent_node = parent or self.hugr.root parent_op = ops.FuncDefn(name, input_types, type_params or []) - return Function.new_nested(parent_op, self.hugr, parent_node) + func = Function.new_nested(parent_op, self.hugr, parent_node) + if output_types is not None: + func.declare_outputs(output_types) + return func def add_const(self, value: val.Value, parent: ToNode | None = None) -> Node: """Add a static constant to the graph. @@ -684,3 +690,38 @@ def __init__( ) -> None: root_op = ops.FuncDefn(name, input_types, type_params or []) super().__init__(root_op) + + def declare_outputs(self, output_types: TypeRow) -> None: + """Declare the output types of the function. + + This is required when calling a function which hasn't been completely + defined yet. The wires passed to :meth:`set_outputs` must match the + declared output types. + """ + self._set_parent_output_count(len(output_types)) + self.parent_op._set_out_types(output_types) + + def set_outputs(self, *args: Wire) -> None: + """Set the outputs of the dataflow graph. + Connects wires to the output node. + + If :meth:`declare_outputs` has been called, the wire types must match + the declared output types. + + Args: + args: Wires to connect to the output node. + + Example: + >>> dfg = Dfg(tys.Bool) + >>> dfg.set_outputs(dfg.inputs()[0]) # connect input to output + """ + if self.parent_op._outputs is not None: + arg_types = [self._get_dataflow_type(w) for w in args] + if arg_types != self.parent_op._outputs: + error_message = ( + f"The function has fixed output type {self.parent_op._outputs}, " + f"but was given output wires with types {arg_types}." + ) + raise ValueError(error_message) + + super().set_outputs(*args) diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 5ae308f26..48f544360 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -274,6 +274,27 @@ def test_mono_function(direct_call: bool) -> None: validate(mod.hugr) +def test_recursive_function() -> None: + mod = Module() + + f_recursive = mod.define_function("recurse", [tys.Qubit]) + f_recursive.declare_outputs([tys.Qubit]) + call = f_recursive.call(f_recursive, f_recursive.input_node[0]) + f_recursive.set_outputs(call) + + validate(mod.hugr) + + +def test_invalid_recursive_function() -> None: + mod = Module() + + f_recursive = mod.define_function("recurse", [tys.Bool], [tys.Qubit]) + f_recursive.call(f_recursive, f_recursive.input_node[0]) + + with pytest.raises(ValueError, match="The function has fixed output type"): + f_recursive.set_outputs(f_recursive.input_node[0]) + + def test_higher_order() -> None: noop_fn = Dfg(tys.Qubit) noop_fn.set_outputs(noop_fn.add(ops.Noop()(noop_fn.input_node[0])))