Skip to content

Commit

Permalink
Merge branch 'main' into embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
loreloc committed Oct 12, 2024
2 parents 979f887 + 6d5b578 commit 78f25b2
Show file tree
Hide file tree
Showing 14 changed files with 2,004 additions and 639 deletions.
9 changes: 8 additions & 1 deletion cirkit/templates/circuit_templates/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def image_data(
'quad-tree-2' (the Quad-Tree with two splits per region node),
'quad-tree-4' (the Quad-Tree with four splits per region node),
'quad-graph' (the Quad-Graph region graph),
'random-binary-tree' (the random binary tree on flattened image pixels),
'poon-domingos' (the Poon-Domingos architecture).
input_layer: The name of the input layer. It can be one of the following:
'categorical' (encoding a Categorical distribution over pixel channel values),
Expand All @@ -53,7 +54,13 @@ def image_data(
Raises:
ValueError: If one of the arguments is not one of the specified allowed ones.
"""
if region_graph not in ["quad-tree-2", "quad-tree-4", "quad-graph", "poon-domingos"]:
if region_graph not in [
"quad-tree-2",
"quad-tree-4",
"quad-graph",
"random-binary-tree",
"poon-domingos",
]:
raise ValueError(f"Unknown region graph called {region_graph}")
if input_layer not in ["categorical", "binomial", "embedding"]:
raise ValueError(f"Unknown input layer called {input_layer}")
Expand Down
10 changes: 9 additions & 1 deletion cirkit/templates/circuit_templates/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
TensorParameter,
UnaryParameterOp,
)
from cirkit.templates.region_graph import PoonDomingos, QuadGraph, QuadTree, RegionGraph
from cirkit.templates.region_graph import (
PoonDomingos,
QuadGraph,
QuadTree,
RandomBinaryTree,
RegionGraph,
)
from cirkit.utils.scope import Scope


Expand Down Expand Up @@ -64,6 +70,8 @@ def build_image_region_graph(
return QuadTree(image_shape, num_patch_splits=4)
case "quad-graph":
return QuadGraph(image_shape)
case "random-binary-tree":
return RandomBinaryTree(np.prod(image_shape))
case "poon-domingos":
delta = max(np.ceil(image_shape[0] / 8), np.ceil(image_shape[1] / 8))
return PoonDomingos(image_shape, delta=delta)
Expand Down
1 change: 1 addition & 0 deletions cirkit/templates/region_graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .algorithms import ChowLiuTree as ChowLiuTree
from .algorithms import FullyFactorized as FullyFactorized
from .algorithms import LinearTree as LinearTree
from .algorithms import PoonDomingos as PoonDomingos
Expand Down
Loading

0 comments on commit 78f25b2

Please sign in to comment.