Merge pull request #278 from LennertDeSmet/release
Add sampling implementation with queries, plus the Binomial layer
loreloc authored Oct 9, 2024
2 parents 1d4f8f3 + 4633fc4 commit e5f1b5c
Showing 16 changed files with 1,740 additions and 20 deletions.
8 changes: 8 additions & 0 deletions .gitignore
@@ -168,3 +168,11 @@ cython_debug/

# Notebooks data
notebooks/datasets/
datasets/MNIST/raw/t10k-images-idx3-ubyte
datasets/MNIST/raw/t10k-images-idx3-ubyte.gz
datasets/MNIST/raw/t10k-labels-idx1-ubyte
datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz
datasets/MNIST/raw/train-images-idx3-ubyte
datasets/MNIST/raw/train-images-idx3-ubyte.gz
datasets/MNIST/raw/train-labels-idx1-ubyte
datasets/MNIST/raw/train-labels-idx1-ubyte.gz
16 changes: 9 additions & 7 deletions cirkit/backend/torch/circuits.py
@@ -41,13 +41,15 @@ def lookup(

# Catch the case there are no inputs coming from other modules
# That is, we are gathering the inputs of input layers
assert in_graph is not None
assert isinstance(entry.module, TorchInputLayer)
# in_graph: An input batch (assignments to variables) of shape (B, C, D)
# scope_idx: The scope of the layers in each fold, a tensor of shape (F, D'), D' < D
# x: (B, C, D) -> (B, C, F, D') -> (F, C, B, D')
x = in_graph[..., entry.module.scope_idx].permute(2, 1, 0, 3)
yield entry.module, (x,)
if in_graph is None:
yield entry.module, ()
else:
# in_graph: An input batch (assignments to variables) of shape (B, C, D)
# scope_idx: The scope of the layers in each fold, a tensor of shape (F, D'), D' < D
# x: (B, C, D) -> (B, C, F, D') -> (F, C, B, D')
x = in_graph[..., entry.module.scope_idx].permute(2, 1, 0, 3)
yield entry.module, (x,)

@classmethod
def from_index_info(
@@ -171,7 +173,7 @@ def _evaluate_layers(self, x: Tensor) -> Tensor:
class TorchCircuit(AbstractTorchCircuit):
"""The tensorized circuit with concrete computational graph in PyTorch.
This class is aimed for computation, and therefore does not include strutural properties.
This class is aimed for computation, and therefore does not include structural properties.
"""

def __call__(self, x: Tensor) -> Tensor:
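As context for the lookup change above: each input layer receives a slice of the input batch selected by its per-fold scope and permuted into the (F, C, B, D') layout, while the new branch passes no inputs at all when in_graph is None (e.g. when sampling). A minimal standalone sketch of the indexing pattern, with illustrative shapes and names that are not part of cirkit's API:

import torch

# Illustrative shapes: B batch, C channels, D variables, F folds, Dp scope size per fold.
B, C, D, F, Dp = 8, 1, 6, 2, 3
in_graph = torch.randn(B, C, D)              # an input batch of variable assignments
scope_idx = torch.randint(0, D, (F, Dp))     # the scope of the layers in each fold

# Indexing the last axis with a (F, Dp) integer tensor gathers each fold's scope
x = in_graph[..., scope_idx]                 # (B, C, F, Dp)
x = x.permute(2, 1, 0, 3)                    # (F, C, B, Dp), the layout input layers consume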
69 changes: 69 additions & 0 deletions cirkit/backend/torch/layers/inner.py
@@ -1,6 +1,8 @@
from abc import ABC
from typing import Any

import einops as E
import torch
from torch import Tensor

from cirkit.backend.torch.layers.base import TorchLayer
@@ -38,6 +40,9 @@ def __init__(
def fold_settings(self) -> tuple[Any, ...]:
return self.num_input_units, self.num_output_units, self.arity

def sample(self, x: Tensor) -> tuple[Tensor, Tensor | None]:
raise TypeError(f"Sampling not implemented for {type(self)}")


class TorchProductLayer(TorchInnerLayer, ABC):
...
@@ -88,6 +93,12 @@ def forward(self, x: Tensor) -> Tensor:
"""
return self.semiring.prod(x, dim=1, keepdim=False) # shape (F, H, B, K) -> (F, B, K).

def sample(self, x: Tensor) -> tuple[Tensor, None]:
# Concatenate samples over disjoint variables through a sum
# x: (F, H, C, K, num_samples, D)
x = torch.sum(x, dim=1) # (F, C, K, num_samples, D)
return x, None


class TorchKroneckerLayer(TorchProductLayer):
"""The Kronecker product layer."""
@@ -133,6 +144,14 @@ def forward(self, x: Tensor) -> Tensor:
# shape (F, B, Ki, Ki) -> (F, B, Ko=Ki**2).
return self.semiring.mul(x0, x1).flatten(start_dim=-2)

def sample(self, x: Tensor) -> tuple[Tensor, Tensor | None]:
# x: (F, H, C, K, num_samples, D)
x0 = x[:, 0].unsqueeze(dim=3) # (F, C, Ki, 1, num_samples, D)
x1 = x[:, 1].unsqueeze(dim=2) # (F, C, 1, Ki, num_samples, D)
# shape (F, C, Ki, Ki, num_samples, D) -> (F, C, Ko=Ki**2, num_samples, D)
x = x0 + x1
return torch.flatten(x, start_dim=2, end_dim=3), None


class TorchDenseLayer(TorchSumLayer):
"""The sum layer for dense sum within a layer."""
@@ -188,6 +207,31 @@ def forward(self, x: Tensor) -> Tensor:
"fbi,foi->fbo", inputs=(x,), operands=(weight,), dim=-1, keepdim=True
) # shape (F, B, Ko).

def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
weight = self.weight()
negative = torch.any(weight < 0.0)
if negative:
raise ValueError("Sampling only works with positive weights")
normalized = torch.allclose(torch.sum(weight, dim=-1), torch.ones(1, device=weight.device))
if not normalized:
raise ValueError("Sampling only works with a normalized parametrization")

# x: (F, H, C, K, num_samples, D)
c = x.shape[2]
d = x.shape[-1]
num_samples = x.shape[-2]

# mixing_distribution: (F, O, K)
mixing_distribution = torch.distributions.Categorical(probs=weight)

mixing_samples = mixing_distribution.sample((num_samples,))
mixing_samples = E.rearrange(mixing_samples, "n f o -> f o n")
mixing_indices = E.repeat(mixing_samples, "f o n -> f a c o n d", a=self.arity, c=c, d=d)

x = torch.gather(x, dim=-3, index=mixing_indices)
x = x[:, 0]
return x, mixing_samples


class TorchMixingLayer(TorchSumLayer):
"""The sum layer for mixture among layers.
@@ -242,3 +286,28 @@ def forward(self, x: Tensor) -> Tensor:
return self.semiring.einsum(
"fhbk,fkh->fbk", inputs=(x,), operands=(weight,), dim=1, keepdim=False
)

def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
weight = self.weight()
negative = torch.any(weight < 0.0)
if negative:
raise ValueError("Sampling only works with positive weights")
normalized = torch.allclose(torch.sum(weight, dim=-1), torch.ones(1, device=weight.device))
if not normalized:
raise ValueError("Sampling only works with a normalized parametrization")

# x: (F, H, C, K, num_samples, D)
c = x.shape[2]
k = x.shape[-3]
d = x.shape[-1]
num_samples = x.shape[-2]

# mixing_distribution: a Categorical over the arity dimension, with probs of shape (F, K, H)
mixing_distribution = torch.distributions.Categorical(probs=weight)

mixing_samples = mixing_distribution.sample((num_samples,))
mixing_samples = E.rearrange(mixing_samples, "n f k -> f k n")
mixing_indices = E.repeat(mixing_samples, "f k n -> f 1 c k n d", c=c, k=k, d=d)

x = torch.gather(x, 1, mixing_indices)[:, 0]
return x, mixing_samples
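The new sum-layer sample methods above all follow the same gather-based ancestral sampling pattern: draw a Categorical index per fold, output unit and sample from the (normalized) weights, then keep the corresponding input sample. A standalone sketch of that pattern with illustrative shapes and names (not cirkit's API), mirroring TorchDenseLayer.sample:

import einops as E
import torch

# Illustrative shapes: F folds, C channels, Ki input units, Ko output units, N samples, D variables.
F, C, Ki, Ko, N, D = 2, 1, 3, 4, 5, 6
x = torch.randn(F, 1, C, Ki, N, D)                       # samples coming from the single input layer
weight = torch.softmax(torch.randn(F, Ko, Ki), dim=-1)   # sum weights, normalized over the Ki inputs

# For every fold, output unit and sample, choose which input unit to propagate
mix = torch.distributions.Categorical(probs=weight).sample((N,))  # (N, F, Ko)
mix = E.rearrange(mix, "n f o -> f o n")                          # (F, Ko, N)
idx = E.repeat(mix, "f o n -> f 1 c o n d", c=C, d=D)             # (F, 1, C, Ko, N, D)

samples = torch.gather(x, dim=-3, index=idx)[:, 0]                # (F, C, Ko, N, D)

TorchMixingLayer.sample uses the same recipe but indexes the arity dimension H instead of the input-unit dimension, and the CP layer in optimized.py first sums over the arity dimension (combining samples over disjoint scopes) before gathering.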
132 changes: 128 additions & 4 deletions cirkit/backend/torch/layers/input.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Optional

import torch
from torch import Tensor, distributions
@@ -13,7 +13,7 @@
class TorchInputLayer(TorchLayer, ABC):
"""The abstract base class for input layers."""

# NOTE: We use exactly the sae interface (F, H, B, K) -> (F, B, K) for __call__ of input layers:
# NOTE: We use exactly the same interface (F, H, B, K) -> (F, B, K) for __call__ of input layers:
# 1. Define arity(H)=num_channels(C), reusing the H dimension.
# 2. Define num_input_units(K)=num_vars(D), which reuses the K dimension.
# For dimension D (variables), we should parse the input in circuit according to the
@@ -87,6 +87,10 @@ def forward(self, x: Tensor) -> Tensor:
def integrate(self) -> Tensor:
...

@abstractmethod
def sample(self, num_samples: int = 1, x: Tensor | None = None) -> Tensor:
...

def extra_repr(self) -> str:
return (
" ".join(
@@ -199,7 +203,8 @@ def __init__(
num_output_units: The number of output units.
num_channels: The number of channels.
num_categories: The number of categories for Categorical distribution. Defaults to 2.
logits: The reparameterization for layer parameters.
probs: The reparameterization for layer probs parameters.
logits: The reparameterization for layer logits parameters.
"""
num_variables = scope_idx.shape[-1]
if num_variables != 1:
@@ -266,6 +271,116 @@ def log_partition_function(self) -> Tensor:
logits = self.logits()
return torch.sum(torch.logsumexp(logits, dim=3), dim=2).unsqueeze(dim=1)

def sample(self, num_samples: int = 1, x: Tensor | None = None) -> Tensor:
raise TypeError("Sampling is not implemented for Categorical layers")


class TorchBinomialLayer(TorchExpFamilyLayer):
"""The Binomial distribution layer.
This is fully factorized down to univariate Binomial distributions.
"""

# DISABLE: It's designed to have these arguments.
# pylint: disable-next=too-many-arguments
def __init__(
self,
scope_idx: Tensor,
num_output_units: int,
*,
num_channels: int = 1,
total_count: int = 1,
probs: Optional[TorchParameter] = None,
logits: Optional[TorchParameter] = None,
semiring: Optional[Semiring] = None,
) -> None:
"""Init class.
Args:
scope_idx: A tensor of shape (F, D), where F is the number of folds, and
D is the number of variables on which the input layers in each fold are defined.
Alternatively, a tensor of shape (D,) can be specified, which will be interpreted
as a tensor of shape (1, D), i.e., with F = 1.
num_output_units: The number of output units.
num_channels: The number of channels.
total_count: The number of trials. Defaults to 1.
probs: The reparameterization for layer probs parameters.
logits: The reparameterization for layer logits parameters.
"""
num_variables = scope_idx.shape[-1]
if num_variables != 1:
raise ValueError("The Binomial layer encodes a univariate distribution")
if total_count < 0:
raise ValueError("The number of trials must be non-negative")
super().__init__(
scope_idx,
num_output_units,
num_channels=num_channels,
semiring=semiring,
)
self.total_count = total_count
if not ((logits is None) ^ (probs is None)):
raise ValueError("Exactly one between 'logits' and 'probs' must be specified")
if logits is None:
assert probs is not None
if not self._valid_parameter_shape(probs):
raise ValueError(f"The number of folds and shape of 'probs' must match the layer's")
else:
if not self._valid_parameter_shape(logits):
raise ValueError(
f"The number of folds and shape of 'logits' must match the layer's"
)
self.probs = probs
self.logits = logits

def _valid_parameter_shape(self, p: TorchParameter) -> bool:
if p.num_folds != self.num_folds:
return False
return p.shape == (
self.num_output_units,
self.num_channels,
)

@property
def config(self) -> dict[str, Any]:
config = super().config
config.update(total_count=self.total_count)
return config

@property
def params(self) -> dict[str, TorchParameter]:
if self.logits is None:
return dict(probs=self.probs)
return dict(logits=self.logits)

def log_unnormalized_likelihood(self, x: Tensor) -> Tensor:
if x.is_floating_point():
x = x.long() # The input to Binomial should be discrete
x = x.permute(0, 2, 3, 1) # (F, C, B, 1) -> (F, B, 1, C)
if self.logits is not None:
logits = self.logits().unsqueeze(dim=1) # (F, 1, K, C)
dist = distributions.Binomial(self.total_count, logits=logits)
else:
probs = self.probs().unsqueeze(dim=1) # (F, 1, K, C)
dist = distributions.Binomial(self.total_count, probs=probs)
x = dist.log_prob(x) # (F, B, K, C)
return torch.sum(x, dim=3)

def log_partition_function(self) -> Tensor:
if self.logits is None:
return torch.zeros(
size=(self.num_folds, 1, self.num_output_units), device=self.probs.device
)
logits = self.logits()
return torch.sum(torch.logsumexp(logits, dim=3), dim=2).unsqueeze(dim=1)

def sample(self, num_samples: int = 1, x: Tensor | None = None) -> Tensor:
# Note: 'logits' are log-odds rather than log-probabilities, so build the
# distribution from whichever parameterization is actually available
if self.logits is None:
    dist = distributions.Binomial(self.total_count, probs=self.probs())
else:
    dist = distributions.Binomial(self.total_count, logits=self.logits())
samples = dist.sample((num_samples,)) # (N, F, K, C)
samples = samples.permute(1, 3, 2, 0) # (F, C, K, N)
return samples


class TorchGaussianLayer(TorchExpFamilyLayer):
"""The Normal distribution layer.
@@ -348,6 +463,9 @@ def log_partition_function(self) -> Tensor:
log_partition = self.log_partition() # (F, K, C)
return torch.sum(log_partition, dim=2).unsqueeze(dim=1)

def sample(self, num_samples: int = 1, x: Tensor | None = None) -> Tensor:
raise TypeError("Sampling is not implemented for Gaussian layers")


class TorchLogPartitionLayer(TorchInputLayer):
def __init__(
@@ -404,6 +522,9 @@ def forward(self, x: Tensor) -> Tensor:
def integrate(self) -> Tensor:
raise TypeError("Cannot integrate a layer computing a log-partition function")

def sample(self, num_samples: int = 1, x: Tensor | None = None) -> Tensor:
raise TypeError("Cannot sample from a layer computing a log-partition function")


# TODO: could be in backends/torch/utils, can be reused by PolyGaussian
def polyval(coeff: Tensor, x: Tensor) -> Tensor:
@@ -500,4 +621,7 @@ def forward(self, x: Tensor) -> Tensor:
return self.semiring.map_from(polyval(coeff, x), SumProductSemiring)

def integrate(self) -> Tensor:
raise TypeError("Cannot integrate a PolynomialLayer")
raise TypeError("Cannot integrate a Polynomial layer")

def sample(self, num_samples: int = 1, x: Tensor | None = None) -> Tensor:
raise TypeError("Cannot sample from a Polynomial layer")
29 changes: 29 additions & 0 deletions cirkit/backend/torch/layers/optimized.py
@@ -1,6 +1,8 @@
from abc import ABC
from typing import Any

import einops as E
import torch
from torch import Tensor

from cirkit.backend.torch.layers import TorchInnerLayer, TorchSumLayer
@@ -133,6 +135,33 @@ def forward(self, x: Tensor) -> Tensor:
"fbi,foi->fbo", inputs=(x,), operands=(weight,), dim=-1, keepdim=True
)

def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
weight = self.weight()
negative = torch.any(weight < 0.0)
if negative:
raise ValueError("Sampling only works with positive weights")
normalized = torch.allclose(torch.sum(weight, dim=-1), torch.ones(1, device=weight.device))
if not normalized:
raise ValueError("Sampling only works with a normalized parametrization")

# x: (F, H, C, K, num_samples, D)
x = torch.sum(x, dim=1, keepdim=True) # (F, H=1, C, K, num_samples, D)

c = x.shape[2]
d = x.shape[-1]
num_samples = x.shape[-2]

# mixing_distribution: (F, O, K)
mixing_distribution = torch.distributions.Categorical(probs=weight)

mixing_samples = mixing_distribution.sample((num_samples,))
mixing_samples = E.rearrange(mixing_samples, "n f o -> f o n")
mixing_indices = E.repeat(mixing_samples, "f o n -> f a c o n d", a=1, c=c, d=d)

x = torch.gather(x, dim=-3, index=mixing_indices)
x = x[:, 0]
return x, mixing_samples


class TorchTensorDotLayer(TorchSumLayer):
"""The sum layer for dense sum within a layer."""