Skip to content

Commit

Permalink
Allow rebuilding a graph in toposort_replace
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jul 6, 2023
1 parent dd3c44d commit 8a14e7e
Show file tree
Hide file tree
Showing 4 changed files with 259 additions and 20 deletions.
3 changes: 1 addition & 2 deletions pymc_experimental/model_transform/conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
model_from_fgraph,
model_named,
model_observed_rv,
toposort_replace,
)
from pymc_experimental.utils.pytensorf import rvs_in_graph
from pymc_experimental.utils.pytensorf import rvs_in_graph, toposort_replace


def observe(model: Model, vars_to_observations: Dict[Union["str", TensorVariable], Any]) -> Model:
Expand Down
71 changes: 71 additions & 0 deletions pymc_experimental/tests/utils/test_pytensorf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import pytensor.tensor as pt
import pytest
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import equal_computations

from pymc_experimental.utils.pytensorf import toposort_replace


class TestToposortReplace:
@pytest.mark.parametrize("compatible_type", (True, False))
@pytest.mark.parametrize("num_replacements", (1, 2))
@pytest.mark.parametrize("rebuild", (True, False))
def test_horizontal_dependency(self, compatible_type, num_replacements, rebuild):
x = pt.vector("x", shape=(5,))
y = pt.vector("y", shape=(5,))

out1 = pt.exp(x + y) + pt.log(x + y)
out2 = pt.cos(out1)

new_shape = (5,) if compatible_type else (10,)
new_x = pt.vector("new_x", shape=new_shape)
new_y = pt.vector("new_y", shape=new_shape)
if num_replacements == 1:
replacements = [(y, new_y)]
else:
replacements = [(x, new_x), (y, new_y)]

fg = FunctionGraph([x, y], [out1, out2], clone=False)

# If types are incompatible, and we don't rebuild or only replace one of the variables,
# The function should fail
if not compatible_type and (not rebuild or num_replacements == 1):
with pytest.raises((TypeError, ValueError)):
toposort_replace(fg, replacements, rebuild=rebuild)
return
toposort_replace(fg, replacements, rebuild=rebuild)

if num_replacements == 1:
expected_out1 = pt.exp(x + new_y) + pt.log(x + new_y)
else:
expected_out1 = pt.exp(new_x + new_y) + pt.log(new_x + new_y)
expected_out2 = pt.cos(expected_out1)
assert equal_computations(fg.outputs, [expected_out1, expected_out2])

@pytest.mark.parametrize("compatible_type", (True, False))
@pytest.mark.parametrize("num_replacements", (2, 3))
@pytest.mark.parametrize("rebuild", (True, False))
def test_vertical_dependency(self, compatible_type, num_replacements, rebuild):
x = pt.vector("x", shape=(5,))
a1 = pt.exp(x)
a2 = pt.log(a1)
out = a1 + a2

new_x = pt.vector("new_x", shape=(5 if compatible_type else 10,))
if num_replacements == 2:
replacements = [(x, new_x), (a1, pt.cos(a1)), (a2, pt.sin(a2 + 5))]
else:
replacements = [(a1, pt.cos(pt.exp(new_x))), (a2, pt.sin(a2 + 5))]

fg = FunctionGraph([x], [out], clone=False)

if not compatible_type and not rebuild:
with pytest.raises(TypeError):
toposort_replace(fg, replacements, rebuild=rebuild)
return
toposort_replace(fg, replacements, rebuild=rebuild)

expected_a1 = pt.cos(pt.exp(new_x))
expected_a2 = pt.sin(pt.log(expected_a1) + 5)
expected_out = expected_a1 + expected_a2
assert equal_computations(fg.outputs, [expected_out])
17 changes: 2 additions & 15 deletions pymc_experimental/utils/model_fgraph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Optional, Sequence, Tuple
from typing import Dict, Optional, Tuple

import pytensor
from pymc.logprob.transforms import RVTransform
Expand All @@ -10,7 +10,7 @@
from pytensor.scalar import Identity
from pytensor.tensor.elemwise import Elemwise

from pymc_experimental.utils.pytensorf import StringType
from pymc_experimental.utils.pytensorf import StringType, toposort_replace


class ModelVar(Op):
Expand Down Expand Up @@ -89,19 +89,6 @@ def model_free_rv(rv, value, transform, *dims):
model_named = ModelNamed()


def toposort_replace(
fgraph: FunctionGraph, replacements: Sequence[Tuple[Variable, Variable]], reverse: bool = False
) -> None:
"""Replace multiple variables in topological order."""
toposort = fgraph.toposort()
sorted_replacements = sorted(
replacements,
key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1,
reverse=reverse,
)
fgraph.replace_all(sorted_replacements, import_missing=True)


@node_rewriter([Elemwise])
def local_remove_identity(fgraph, node):
if isinstance(node.op.scalar_op, Identity):
Expand Down
188 changes: 185 additions & 3 deletions pymc_experimental/utils/pytensorf.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Sequence
from collections import deque
from itertools import chain
from typing import Iterable, Sequence, Set, Tuple

import pytensor
from pymc import SymbolicRandomVariable
from pytensor import Variable
from pytensor.graph import Constant, Type
from pytensor.graph.basic import walk
from pytensor.graph import Constant, FunctionGraph, Type
from pytensor.graph.basic import Apply, walk
from pytensor.graph.op import HasInnerGraph
from pytensor.tensor.random.op import RandomVariable

Expand Down Expand Up @@ -58,3 +60,183 @@ def expand(r):
for node in walk(vars, expand, False)
if node.owner and isinstance(node.owner.op, (RandomVariable, SymbolicRandomVariable))
)


# def replace_rebuild_all(
# fgraph: FunctionGraph,
# replacements: Sequence[Tuple[Variable, Variable]],
# reason: Optional[str] = None,
# verbose: Optional[bool] = None,
# import_missing: bool = False,
# ) -> None:
# """Replace a variable in the `FunctionGraph` and rebuild the graph.
#
# This is the main interface to manipulate the subgraph in `FunctionGraph`.
# For every node that uses `var` as input, makes it use `new_var` instead.
#
# Parameters
# ----------
# fgraph
# The FunctionGraph where replacements are performed
# var
# The variable to be replaced.
# new_var
# The variable to replace `var`.
# reason
# The name of the optimization or operation in progress.
# verbose
# Print `reason`, `var`, and `new_var`.
# import_missing
# Import missing variables.
#
# """
# # if verbose is None:
# # verbose = config.optimizer_verbose
# # if verbose:
# # print(
# # f"rewriting: rewrite {reason} replaces {var} of {var.owner} with {new_var} of {new_var.owner}"
# # )
# #
# # if var not in fgraph.variables:
# # return
#
#
#
# def get_client_nodes(vars):
# nodes = set()
# d = deque(chain.from_iterable(fgraph.clients[var] for var in vars))
# while d:
# node, _ = d.pop()
# if node == "output":
# continue
# if node in nodes:
# continue
# nodes.add(node)
# d.extend(chain.from_iterable(fgraph.clients[out] for out in node.outputs))
# return nodes
#
# topo_order = {node: order for order, node in enumerate(fgraph.toposort())}
# old_vars = [old for old, _ in replacements]
# d = deque(sorted(get_client_nodes(old_vars), key=lambda node: topo_order[node]))
#
# repl_dict = {old: new for old, new in replacements}
# outputs = set(fgraph.outputs)
# while d:
# node: Apply = d.popleft()
#
# new_inputs = [repl_dict.get(i, i) for i in node.inputs]
# if new_inputs == node.inputs:
# continue
#
# new_node = node.clone_with_new_inputs(new_inputs, strict=False)
# for out, new_out in zip(node.outputs, new_node.outputs):
# repl_dict[out] = new_out
#
# return FunctionGraph(outputs=outputs, clone=False)


def _replace_rebuild_all(
fgraph: FunctionGraph, replacements: Iterable[Tuple[Variable, Variable]], **kwargs
) -> FunctionGraph:
"""Replace variables and rebuild dependent graph if needed.
Rebuilding allows for replacements that change the semantics of the graph
(different types), which may not be possible for all Ops.
"""

def get_client_nodes(vars) -> Set[Apply]:
nodes = set()
d = deque(
chain.from_iterable(fgraph.clients[var] for var in vars if var in fgraph.variables)
)
while d:
node, _ = d.pop()
if node in nodes or node == "output":
continue
nodes.add(node)
d.extend(chain.from_iterable(fgraph.clients[out] for out in node.outputs))
return nodes

repl_dict = {old: new for old, new in replacements}
root_nodes = {var.owner for var in repl_dict.keys()}

# Build sorted queue with all nodes that depend on replaced variables
topo_order = {node: order for order, node in enumerate(fgraph.toposort())}
client_nodes = get_client_nodes(repl_dict.keys())
d = deque(sorted(client_nodes, key=lambda node: topo_order[node]))
while d:
node = d.popleft()
if node in root_nodes:
continue

new_inputs = [repl_dict.get(i, i) for i in node.inputs]
if new_inputs == node.inputs:
continue

# Either remake the node or do a simple inplace replacement
# This property is not yet present in PyTensor
if getattr(node.op, "_output_type_depends_on_input_value", False):
remake_node = True
else:
remake_node = any(
not inp.type == new_inp.type for inp, new_inp in zip(node.inputs, new_inputs)
)

if remake_node:
new_node = node.clone_with_new_inputs(new_inputs, strict=False)
fgraph.import_node(new_node, import_missing=True)
for out, new_out in zip(node.outputs, new_node.outputs):
repl_dict[out] = new_out
else:
replace = list(zip(node.inputs, new_inputs))
fgraph.replace_all(replace, import_missing=True)

# We need special logic for the cases where we had to rebuild the output nodes
for i, (new_output, old_output) in enumerate(
zip(
(repl_dict.get(out, out) for out in fgraph.outputs),
fgraph.outputs,
)
):
if new_output is old_output:
continue
fgraph.outputs[i] = new_output
fgraph.import_var(new_output, import_missing=True)
client = ("output", i)
fgraph.add_client(new_output, client)
fgraph.remove_client(old_output, client)
fgraph.execute_callbacks("on_change_input", "output", i, old_output, new_output)


def toposort_replace(
fgraph: FunctionGraph,
replacements: Sequence[Tuple[Variable, Variable]],
reverse: bool = False,
rebuild: bool = False,
) -> None:
"""Replace multiple variables in topological order."""
if rebuild and reverse:
raise NotImplementedError("reverse rebuild not supported")

toposort = fgraph.toposort()
sorted_replacements = sorted(
replacements,
key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1,
reverse=reverse,
)

if rebuild:
if len(replacements) > 1:
# In this case we need to introduce the replacements inside each other
# To avoid undoing previous changes
sorted_replacements = [list(pairs) for pairs in sorted_replacements]
for i in range(1, len(replacements)):
# Replace-rebuild each successive replacement with the previous replacements (in topological order)
temp_fgraph = FunctionGraph(
outputs=[repl for _, repl in sorted_replacements[i:]], clone=False
)
_replace_rebuild_all(temp_fgraph, replacements=sorted_replacements[:i])
sorted_replacements[i][1] = temp_fgraph.outputs[0]
_replace_rebuild_all(fgraph, sorted_replacements, import_missing=True)
else:
fgraph.replace_all(sorted_replacements, import_missing=True)

0 comments on commit 8a14e7e

Please sign in to comment.