Skip to content

Commit

Permalink
fix RGNodeMetadata typing and RG is_compatible method
Browse files Browse the repository at this point in the history
  • Loading branch information
loreloc committed Oct 9, 2024
1 parent 0232014 commit 9687c32
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions cirkit/templates/region_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
from collections import defaultdict
from collections.abc import Iterable, Iterator
from functools import cached_property
from typing import TypeAlias, TypedDict, Union, cast, final
from typing import TypeAlias, TypedDict, cast, final

import numpy as np
from numpy.typing import NDArray

from cirkit.utils.algorithms import DiAcyclicGraph
from cirkit.utils.scope import Scope

RGNodeMetadata: TypeAlias = dict[str, Union[int, float, str, bool]]
RGNodeMetadata: TypeAlias = dict[str, int | float | str | bool]


class RegionDict(TypedDict):
Expand Down Expand Up @@ -67,14 +67,10 @@ def __repr__(self) -> str:
class RegionNode(RegionGraphNode):
"""The region node in the region graph."""

...


class PartitionNode(RegionGraphNode):
"""The partition node in the region graph."""

...


# We mark RG as final to hint that RG algorithms should not be its subclasses but factories, so that
# constructed RGs and loaded RGs are all of type RegionGraph.
Expand All @@ -95,16 +91,19 @@ def _check_structure(self):
for ptn in node_children:
if not isinstance(ptn, PartitionNode):
raise ValueError(
f"Expected partition node as children of '{node}', but found '{ptn}'"
f"Expected partition node as children of '{node}', "
f"but found '{ptn}'"
)
if ptn.scope != node.scope:
raise ValueError(
f"Expectet partition node with scope '{node.scope}', but found '{ptn.scope}"
f"Expected partition node with scope '{node.scope}', "
f"but found '{ptn.scope}"
)
continue
if not isinstance(node, PartitionNode):
raise ValueError(
f"Region graph nodes must be either partition nodes or region nodes, found '{type(node)}'"
f"Region graph nodes must be either partition nodes or region nodes, "
f"found '{type(node)}'"
)
scopes = []
for rgn in node_children:
Expand Down Expand Up @@ -217,14 +216,17 @@ def is_compatible(self, other: "RegionGraph", /, *, scope: Iterable[int] | None
if partition1.scope & scope != partition2.scope & scope:
continue # Only check partitions with the same scope.

if any(partition1.scope <= input.scope for input in partition2.inputs) or any(
partition2.scope <= input.scope for input in partition1.inputs
partition1_inputs = self.node_inputs(partition1)
partition2_inputs = self.node_inputs(partition2)

if any(partition1.scope <= input.scope for input in partition2_inputs) or any(
partition2.scope <= input.scope for input in partition1_inputs
):
continue # Only check partitions not within another partition.

adj_mat = np.zeros((len(partition1.inputs), len(partition2.inputs)), dtype=np.bool_)
adj_mat = np.zeros((len(partition1_inputs), len(partition2_inputs)), dtype=np.bool_)
for (i, region1), (j, region2) in itertools.product(
enumerate(partition1.inputs), enumerate(partition2.inputs)
enumerate(partition1_inputs), enumerate(partition2_inputs)
):
# I.e., if scopes intersect over the scope to test.
adj_mat[i, j] = bool(region1.scope & region2.scope & scope)
Expand Down

0 comments on commit 9687c32

Please sign in to comment.