Commit

Merge branch 'embeddings' of github.com:april-tools/cirkit into embeddings
loreloc committed Oct 12, 2024
2 parents 78f25b2 + 8e4872a commit dcf9252
Showing 3 changed files with 8 additions and 11 deletions.
11 changes: 5 additions & 6 deletions cirkit/backend/torch/layers/input.py
@@ -168,9 +168,8 @@ def forward(self, x: Tensor) -> Tensor:
         x = x.long()  # The input to Embedding should be discrete
         x = F.one_hot(x, self.num_states)  # (F, C, B, 1, num_states)
         x = x.squeeze(dim=3)  # (F, C, B, num_states)
-        x = x.to(torch.get_default_dtype())
         weight = self.weight()
-        x = torch.einsum("fcbi,fkci->fbkc", x, weight)
+        x = torch.einsum("fcbi,fkci->fbkc", x.to(weight.dtype), weight)
         x = self.semiring.map_from(x, SumProductSemiring)
         return self.semiring.prod(x, dim=-1)  # (F, B, K)
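
Note: a minimal standalone sketch of the pattern this hunk fixes, with made-up sizes (not taken from cirkit). Casting the one-hot tensor to weight.dtype rather than torch.get_default_dtype() keeps the einsum well-typed even when the layer's parameters use a non-default precision:

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes: F=2 folds, C=1 channel, B=3 batch, K=4 units, 5 states.
weight = torch.randn(2, 4, 1, 5, dtype=torch.float64)  # (F, K, C, num_states)
x = torch.randint(0, 5, (2, 1, 3))                     # (F, C, B), discrete

x = F.one_hot(x, 5)  # (F, C, B, num_states), dtype int64
# Cast to the weight's dtype, not the global default: if the default were
# float32 while the weight is float64, the einsum would raise a dtype error.
out = torch.einsum("fcbi,fkci->fbkc", x.to(weight.dtype), weight)
print(out.shape)  # torch.Size([2, 3, 4, 1])
```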

@@ -326,9 +325,8 @@ def log_unnormalized_likelihood(self, x: Tensor) -> Tensor:
         x = x.long()  # The input to Categorical should be discrete
         x = F.one_hot(x, self.num_categories)  # (F, C, B, 1, num_categories)
         x = x.squeeze(dim=3)  # (F, C, B, num_categories)
-        x = x.to(torch.get_default_dtype())
         logits = torch.log(self.probs()) if self.logits is None else self.logits()
-        x = torch.einsum("fcbi,fkci->fbk", x, logits)
+        x = torch.einsum("fcbi,fkci->fbk", x.to(logits.dtype), logits)
         return x

     def log_partition_function(self) -> Tensor:
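
The same pattern applies here: multiplying the one-hot encoding against the logits selects one log-probability per output unit, so the cast only needs to match logits.dtype. A hedged sketch with made-up sizes:

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes: F=1 fold, C=1 channel, B=4 batch, K=2 units, 3 categories.
logits = torch.log_softmax(torch.randn(1, 2, 1, 3, dtype=torch.float64), dim=-1)
x = torch.randint(0, 3, (1, 1, 4))  # (F, C, B), discrete

x = F.one_hot(x, 3)  # (F, C, B, num_categories), dtype int64
# One-hot against log-probabilities is just indexing: entry [f, b, k] is the
# log-likelihood of observation x[f, :, b] under output unit k.
ll = torch.einsum("fcbi,fkci->fbk", x.to(logits.dtype), logits)
print(ll.shape)  # torch.Size([1, 4, 2])
```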
@@ -567,11 +565,12 @@ def __init__(
         assert value.num_folds == self.num_folds
         assert value.shape == (num_output_units,)
         self.value = value
-        self._source_semiring = LSESumSemiring if log_space else LSESumSemiring
+        self.log_space = log_space
+        self._source_semiring = LSESumSemiring if log_space else SumProductSemiring

     @property
     def config(self) -> Mapping[str, Any]:
-        return {"num_output_units": self.num_output_units}
+        return {"num_output_units": self.num_output_units, "log_space": self.log_space}

     @property
     def params(self) -> Mapping[str, TorchParameter]:
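
This hunk fixes two related bugs: the else branch previously yielded LSESumSemiring in both cases, and log_space was missing from config, so a layer rebuilt from its config would silently drop the flag. A hypothetical miniature (not the library's actual class) of the round-trip the fix enables:

```python
class ConstantValue:
    """Hypothetical stand-in for the layer's config handling."""

    def __init__(self, num_output_units: int, *, log_space: bool = False):
        self.num_output_units = num_output_units
        self.log_space = log_space
        # After the fix, the source semiring actually depends on log_space.
        self._source_semiring = "LSESum" if log_space else "SumProduct"

    @property
    def config(self) -> dict:
        return {"num_output_units": self.num_output_units, "log_space": self.log_space}


layer = ConstantValue(8, log_space=True)
clone = ConstantValue(**layer.config)  # round-trips only because config stores log_space
assert clone._source_semiring == layer._source_semiring
```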
5 changes: 1 addition & 4 deletions cirkit/symbolic/operators.py
@@ -2,15 +2,12 @@

 from cirkit.symbolic.circuit import CircuitBlock
 from cirkit.symbolic.layers import (
     BinomialLayer,
     CategoricalLayer,
     ConstantValueLayer,
     DenseLayer,
     EmbeddingLayer,
     EvidenceLayer,
     GaussianLayer,
     HadamardLayer,
     KroneckerLayer,
     Layer,
     LayerOperator,
     MixingLayer,
@@ -46,7 +43,7 @@ def integrate_embedding_layer(sl: EmbeddingLayer, *, scope: Scope) -> CircuitBlock:
     reduce_sum = ReduceSumParameter(sl.weight.shape, axis=2)
     reduce_prod = ReduceProductParameter(reduce_sum.shape, axis=1)
     value = Parameter.from_sequence(sl.weight.ref(), reduce_sum, reduce_prod)
-    sl = ConstantValueLayer(sl.num_output_units, value=value)
+    sl = ConstantValueLayer(sl.num_output_units, log_space=False, value=value)
     return CircuitBlock.from_layer(sl)
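
Numerically, the parameter chain built here sums the embedding weight over its state axis and then multiplies over channels, producing one constant per output unit; log_space=False matches the fact that this value lives in linear space. A small sketch with plain tensor ops, assuming the symbolic weight has shape (num_output_units, num_channels, num_states), the axes implied by axis=2 followed by axis=1:

```python
import torch

weight = torch.rand(4, 2, 5)  # hypothetical (num_output_units, channels, states)

# ReduceSumParameter(axis=2): integrate each channel's embedding over states.
summed = weight.sum(dim=2)   # (4, 2)
# ReduceProductParameter(axis=1): factorized channels multiply together.
value = summed.prod(dim=1)   # (4,) -> one linear-space constant per unit
print(value.shape)  # torch.Size([4])
```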

3 changes: 2 additions & 1 deletion docs/api/backend/torch/layers/input.md
@@ -1,5 +1,6 @@
 ::: cirkit.backend.torch.layers.input.TorchInputLayer
-::: cirkit.backend.torch.layers.input.TorchConstantLayer
+::: cirkit.backend.torch.layers.input.TorchConstantValueLayer
 ::: cirkit.backend.torch.layers.input.TorchExpFamilyLayer
 ::: cirkit.backend.torch.layers.input.TorchCategoricalLayer
 ::: cirkit.backend.torch.layers.input.TorchGaussianLayer
+::: cirkit.backend.torch.layers.input.TorchLogPartitionLayer
