From 8a14e7ece42bf9f9b98d28c2c3cad8e772a7d31b Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 6 Jul 2023 20:06:13 +0200
Subject: [PATCH] Allow rebuilding a graph in toposort_replace

---
Reviewer notes (this section is ignored by `git am`):
- Dropped the commented-out `replace_rebuild_all` draft from pytensorf.py.
- `_replace_rebuild_all` mutates the graph in place, so it is annotated
  `-> None`; its ignored `**kwargs` are now documented.
- `test_vertical_dependency`: the three-pair replacement list now runs for
  `num_replacements == 3` (the two branches were swapped).
- Post-image blob ids on the `index` lines predate these edits; regenerate
  the patch with `git format-patch` after applying if exact ids matter.

 .../model_transform/conditioning.py           |   3 +-
 .../tests/utils/test_pytensorf.py             |  71 +++++++
 pymc_experimental/utils/model_fgraph.py       |  17 +-
 pymc_experimental/utils/pytensorf.py          | 145 +++++++++++++-
 4 files changed, 216 insertions(+), 20 deletions(-)
 create mode 100644 pymc_experimental/tests/utils/test_pytensorf.py

diff --git a/pymc_experimental/model_transform/conditioning.py b/pymc_experimental/model_transform/conditioning.py
index fb4468c87..caaf82cce 100644
--- a/pymc_experimental/model_transform/conditioning.py
+++ b/pymc_experimental/model_transform/conditioning.py
@@ -14,9 +14,8 @@
     model_from_fgraph,
     model_named,
     model_observed_rv,
-    toposort_replace,
 )
-from pymc_experimental.utils.pytensorf import rvs_in_graph
+from pymc_experimental.utils.pytensorf import rvs_in_graph, toposort_replace
 
 
 def observe(model: Model, vars_to_observations: Dict[Union["str", TensorVariable], Any]) -> Model:
diff --git a/pymc_experimental/tests/utils/test_pytensorf.py b/pymc_experimental/tests/utils/test_pytensorf.py
new file mode 100644
index 000000000..0f6945710
--- /dev/null
+++ b/pymc_experimental/tests/utils/test_pytensorf.py
@@ -0,0 +1,71 @@
+import pytensor.tensor as pt
+import pytest
+from pytensor.graph import FunctionGraph
+from pytensor.graph.basic import equal_computations
+
+from pymc_experimental.utils.pytensorf import toposort_replace
+
+
+class TestToposortReplace:
+    @pytest.mark.parametrize("compatible_type", (True, False))
+    @pytest.mark.parametrize("num_replacements", (1, 2))
+    @pytest.mark.parametrize("rebuild", (True, False))
+    def test_horizontal_dependency(self, compatible_type, num_replacements, rebuild):
+        x = pt.vector("x", shape=(5,))
+        y = pt.vector("y", shape=(5,))
+
+        out1 = pt.exp(x + y) + pt.log(x + y)
+        out2 = pt.cos(out1)
+
+        new_shape = (5,) if compatible_type else (10,)
+        new_x = pt.vector("new_x", shape=new_shape)
+        new_y = pt.vector("new_y", shape=new_shape)
+        if num_replacements == 1:
+            replacements = [(y, new_y)]
+        else:
+            replacements = [(x, new_x), (y, new_y)]
+
+        fg = FunctionGraph([x, y], [out1, out2], clone=False)
+
+        # If types are incompatible, and we don't rebuild or only replace one of the variables,
+        # The function should fail
+        if not compatible_type and (not rebuild or num_replacements == 1):
+            with pytest.raises((TypeError, ValueError)):
+                toposort_replace(fg, replacements, rebuild=rebuild)
+            return
+        toposort_replace(fg, replacements, rebuild=rebuild)
+
+        if num_replacements == 1:
+            expected_out1 = pt.exp(x + new_y) + pt.log(x + new_y)
+        else:
+            expected_out1 = pt.exp(new_x + new_y) + pt.log(new_x + new_y)
+        expected_out2 = pt.cos(expected_out1)
+        assert equal_computations(fg.outputs, [expected_out1, expected_out2])
+
+    @pytest.mark.parametrize("compatible_type", (True, False))
+    @pytest.mark.parametrize("num_replacements", (2, 3))
+    @pytest.mark.parametrize("rebuild", (True, False))
+    def test_vertical_dependency(self, compatible_type, num_replacements, rebuild):
+        x = pt.vector("x", shape=(5,))
+        a1 = pt.exp(x)
+        a2 = pt.log(a1)
+        out = a1 + a2
+
+        new_x = pt.vector("new_x", shape=(5 if compatible_type else 10,))
+        if num_replacements == 3:
+            replacements = [(x, new_x), (a1, pt.cos(a1)), (a2, pt.sin(a2 + 5))]
+        else:
+            replacements = [(a1, pt.cos(pt.exp(new_x))), (a2, pt.sin(a2 + 5))]
+
+        fg = FunctionGraph([x], [out], clone=False)
+
+        if not compatible_type and not rebuild:
+            with pytest.raises(TypeError):
+                toposort_replace(fg, replacements, rebuild=rebuild)
+            return
+        toposort_replace(fg, replacements, rebuild=rebuild)
+
+        expected_a1 = pt.cos(pt.exp(new_x))
+        expected_a2 = pt.sin(pt.log(expected_a1) + 5)
+        expected_out = expected_a1 + expected_a2
+        assert equal_computations(fg.outputs, [expected_out])
diff --git a/pymc_experimental/utils/model_fgraph.py b/pymc_experimental/utils/model_fgraph.py
index b51f44d62..a6b33205c 100644
--- a/pymc_experimental/utils/model_fgraph.py
+++ b/pymc_experimental/utils/model_fgraph.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Sequence, Tuple
+from typing import Dict, Optional, Tuple
 
 import pytensor
 from pymc.logprob.transforms import RVTransform
@@ -10,7 +10,7 @@
 from pytensor.scalar import Identity
 from pytensor.tensor.elemwise import Elemwise
 
-from pymc_experimental.utils.pytensorf import StringType
+from pymc_experimental.utils.pytensorf import StringType, toposort_replace
 
 
 class ModelVar(Op):
@@ -89,19 +89,6 @@ def model_free_rv(rv, value, transform, *dims):
 model_named = ModelNamed()
 
 
-def toposort_replace(
-    fgraph: FunctionGraph, replacements: Sequence[Tuple[Variable, Variable]], reverse: bool = False
-) -> None:
-    """Replace multiple variables in topological order."""
-    toposort = fgraph.toposort()
-    sorted_replacements = sorted(
-        replacements,
-        key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1,
-        reverse=reverse,
-    )
-    fgraph.replace_all(sorted_replacements, import_missing=True)
-
-
 @node_rewriter([Elemwise])
 def local_remove_identity(fgraph, node):
     if isinstance(node.op.scalar_op, Identity):
diff --git a/pymc_experimental/utils/pytensorf.py b/pymc_experimental/utils/pytensorf.py
index a953b5c16..7b89f29f5 100644
--- a/pymc_experimental/utils/pytensorf.py
+++ b/pymc_experimental/utils/pytensorf.py
@@ -1,10 +1,12 @@
-from typing import Sequence
+from collections import deque
+from itertools import chain
+from typing import Iterable, Sequence, Set, Tuple
 
 import pytensor
 from pymc import SymbolicRandomVariable
 from pytensor import Variable
-from pytensor.graph import Constant, Type
-from pytensor.graph.basic import walk
+from pytensor.graph import Constant, FunctionGraph, Type
+from pytensor.graph.basic import Apply, walk
 from pytensor.graph.op import HasInnerGraph
 from pytensor.tensor.random.op import RandomVariable
 
@@ -58,3 +60,140 @@ def expand(r):
         for node in walk(vars, expand, False)
         if node.owner and isinstance(node.owner.op, (RandomVariable, SymbolicRandomVariable))
     )
+
+
+def _replace_rebuild_all(
+    fgraph: FunctionGraph, replacements: Iterable[Tuple[Variable, Variable]], **kwargs
+) -> None:
+    """Replace variables and rebuild dependent graph if needed.
+
+    Rebuilding allows for replacements that change the semantics of the graph
+    (different types), which may not be possible for all Ops.
+
+    The `fgraph` is modified in place and nothing is returned.
+
+    Parameters
+    ----------
+    fgraph
+        The FunctionGraph where replacements are performed.
+    replacements
+        Pairs of ``(old_variable, new_variable)``.
+    **kwargs
+        Accepted so callers can pass `replace_all`-style options such as
+        ``import_missing``; they are currently ignored.
+    """
+
+    def get_client_nodes(vars) -> Set[Apply]:
+        nodes = set()
+        d = deque(
+            chain.from_iterable(fgraph.clients[var] for var in vars if var in fgraph.variables)
+        )
+        while d:
+            node, _ = d.pop()
+            if node in nodes or node == "output":
+                continue
+            nodes.add(node)
+            d.extend(chain.from_iterable(fgraph.clients[out] for out in node.outputs))
+        return nodes
+
+    repl_dict = {old: new for old, new in replacements}
+    root_nodes = {var.owner for var in repl_dict}
+
+    # Build sorted queue with all nodes that depend on replaced variables
+    topo_order = {node: order for order, node in enumerate(fgraph.toposort())}
+    client_nodes = get_client_nodes(repl_dict.keys())
+    d = deque(sorted(client_nodes, key=lambda node: topo_order[node]))
+    while d:
+        node = d.popleft()
+        if node in root_nodes:
+            continue
+
+        new_inputs = [repl_dict.get(i, i) for i in node.inputs]
+        if new_inputs == node.inputs:
+            continue
+
+        # Either remake the node or do a simple inplace replacement
+        # This property is not yet present in PyTensor
+        if getattr(node.op, "_output_type_depends_on_input_value", False):
+            remake_node = True
+        else:
+            remake_node = any(
+                not inp.type == new_inp.type for inp, new_inp in zip(node.inputs, new_inputs)
+            )
+
+        if remake_node:
+            new_node = node.clone_with_new_inputs(new_inputs, strict=False)
+            fgraph.import_node(new_node, import_missing=True)
+            for out, new_out in zip(node.outputs, new_node.outputs):
+                repl_dict[out] = new_out
+        else:
+            replace = list(zip(node.inputs, new_inputs))
+            fgraph.replace_all(replace, import_missing=True)
+
+    # We need special logic for the cases where we had to rebuild the output nodes
+    for i, (new_output, old_output) in enumerate(
+        zip(
+            (repl_dict.get(out, out) for out in fgraph.outputs),
+            fgraph.outputs,
+        )
+    ):
+        if new_output is old_output:
+            continue
+        fgraph.outputs[i] = new_output
+        fgraph.import_var(new_output, import_missing=True)
+        client = ("output", i)
+        fgraph.add_client(new_output, client)
+        fgraph.remove_client(old_output, client)
+        fgraph.execute_callbacks("on_change_input", "output", i, old_output, new_output)
+
+
+def toposort_replace(
+    fgraph: FunctionGraph,
+    replacements: Sequence[Tuple[Variable, Variable]],
+    reverse: bool = False,
+    rebuild: bool = False,
+) -> None:
+    """Replace multiple variables in topological order.
+
+    Parameters
+    ----------
+    fgraph
+        The FunctionGraph where replacements are performed (modified in place).
+    replacements
+        Pairs of ``(old_variable, new_variable)``.
+    reverse
+        Sort the replacements in reverse topological order.
+    rebuild
+        Rebuild the nodes that depend on the replaced variables, which allows
+        replacements whose type differs from the original variable.
+
+    Raises
+    ------
+    NotImplementedError
+        If both `rebuild` and `reverse` are requested.
+    """
+    if rebuild and reverse:
+        raise NotImplementedError("reverse rebuild not supported")
+
+    toposort = fgraph.toposort()
+    sorted_replacements = sorted(
+        replacements,
+        key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1,
+        reverse=reverse,
+    )
+
+    if rebuild:
+        if len(replacements) > 1:
+            # In this case we need to introduce the replacements inside each other
+            # To avoid undoing previous changes
+            sorted_replacements = [list(pairs) for pairs in sorted_replacements]
+            for i in range(1, len(replacements)):
+                # Replace-rebuild each successive replacement with the previous replacements (in topological order)
+                temp_fgraph = FunctionGraph(
+                    outputs=[repl for _, repl in sorted_replacements[i:]], clone=False
+                )
+                _replace_rebuild_all(temp_fgraph, replacements=sorted_replacements[:i])
+                sorted_replacements[i][1] = temp_fgraph.outputs[0]
+        _replace_rebuild_all(fgraph, sorted_replacements, import_missing=True)
+    else:
+        fgraph.replace_all(sorted_replacements, import_missing=True)