diff --git a/cirkit/templates/region_graph/graph.py b/cirkit/templates/region_graph/graph.py index eac1370a..694d8757 100644 --- a/cirkit/templates/region_graph/graph.py +++ b/cirkit/templates/region_graph/graph.py @@ -4,7 +4,7 @@ from collections import defaultdict from collections.abc import Iterable, Iterator from functools import cached_property -from typing import TypeAlias, TypedDict, Union, cast, final +from typing import TypeAlias, TypedDict, cast, final import numpy as np from numpy.typing import NDArray @@ -12,7 +12,7 @@ from cirkit.utils.algorithms import DiAcyclicGraph from cirkit.utils.scope import Scope -RGNodeMetadata: TypeAlias = dict[str, Union[int, float, str, bool]] +RGNodeMetadata: TypeAlias = dict[str, int | float | str | bool] class RegionDict(TypedDict): @@ -67,14 +67,10 @@ def __repr__(self) -> str: class RegionNode(RegionGraphNode): """The region node in the region graph.""" - ... - class PartitionNode(RegionGraphNode): """The partition node in the region graph.""" - ... - # We mark RG as final to hint that RG algorithms should not be its subclasses but factories, so that # constructed RGs and loaded RGs are all of type RegionGraph. @@ -95,16 +91,19 @@ def _check_structure(self): for ptn in node_children: if not isinstance(ptn, PartitionNode): raise ValueError( - f"Expected partition node as children of '{node}', but found '{ptn}'" + f"Expected partition node as children of '{node}', " + f"but found '{ptn}'" ) if ptn.scope != node.scope: raise ValueError( - f"Expectet partition node with scope '{node.scope}', but found '{ptn.scope}" + f"Expected partition node with scope '{node.scope}', " + f"but found '{ptn.scope}" ) continue if not isinstance(node, PartitionNode): raise ValueError( - f"Region graph nodes must be either partition nodes or region nodes, found '{type(node)}'" + f"Region graph nodes must be either partition nodes or region nodes, " + f"found '{type(node)}'" ) scopes = [] for rgn in node_children: @@ -217,14 +216,17 @@ def is_compatible(self, other: "RegionGraph", /, *, scope: Iterable[int] | None if partition1.scope & scope != partition2.scope & scope: continue # Only check partitions with the same scope. - if any(partition1.scope <= input.scope for input in partition2.inputs) or any( - partition2.scope <= input.scope for input in partition1.inputs + partition1_inputs = self.node_inputs(partition1) + partition2_inputs = self.node_inputs(partition2) + + if any(partition1.scope <= input.scope for input in partition2_inputs) or any( + partition2.scope <= input.scope for input in partition1_inputs ): continue # Only check partitions not within another partition. - adj_mat = np.zeros((len(partition1.inputs), len(partition2.inputs)), dtype=np.bool_) + adj_mat = np.zeros((len(partition1_inputs), len(partition2_inputs)), dtype=np.bool_) for (i, region1), (j, region2) in itertools.product( - enumerate(partition1.inputs), enumerate(partition2.inputs) + enumerate(partition1_inputs), enumerate(partition2_inputs) ): # I.e., if scopes intersect over the scope to test. adj_mat[i, j] = bool(region1.scope & region2.scope & scope)