Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix categorical nans #317

Merged
merged 3 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 9 additions & 25 deletions cirkit/backend/torch/layers/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,6 @@
raise ValueError(f"The number of folds and shape of 'probs' must match the layer's")
self.probs = probs
self.logits = logits
self.idx_mode = (
len(torch.unique(self.scope_idx)) > 4096 or self.num_categories > 256
)

def _valid_parameter_shape(self, p: TorchParameter) -> bool:
if p.num_folds != self.num_folds:
Expand Down Expand Up @@ -326,30 +323,17 @@
def log_unnormalized_likelihood(self, x: Tensor) -> Tensor:
if x.is_floating_point():
x = x.long() # The input to Categorical should be discrete
# x: (F, C, B, 1) -> (F, C, B)
x = x.squeeze(dim=3)
# logits: (F, K, C, N)
logits = torch.log(self.probs()) if self.logits is None else self.logits()
if self.idx_mode:
if self.num_channels == 1:
x = (
logits[:, :, 0, :]
.transpose(1, 2)[range(self.num_folds), x[:, 0, :, 0].t()]
.transpose(0, 1)
)
else:
x = x[..., 0].permute(2, 0, 1)
x = (
logits[
torch.arange(self.num_folds).unsqueeze(1),
:,
torch.arange(self.num_channels).unsqueeze(0),
x,
]
.sum(2)
.transpose(0, 1)
)
if self.num_channels == 1:
idx_fold = torch.arange(self.num_folds, device=logits.device)
x = logits[:, :, 0][idx_fold[:, None], :, x[:, 0]]
else:
x = F.one_hot(x, self.num_categories) # (F, C, B, 1, num_categories)
x = x.squeeze(dim=3) # (F, C, B, num_categories)
x = torch.einsum("fcbi,fkci->fbk", x.to(logits.dtype), logits)
idx_fold = torch.arange(self.num_folds, device=logits.device)[:, None, None]
idx_channel = torch.arange(self.num_channels)[None, :, None]
x = torch.sum(logits[idx_fold, :, idx_channel, x], dim=1)

Check warning on line 336 in cirkit/backend/torch/layers/input.py

View check run for this annotation

Codecov / codecov/patch

cirkit/backend/torch/layers/input.py#L334-L336

Added lines #L334 - L336 were not covered by tests
return x

def log_partition_function(self) -> Tensor:
Expand Down
22 changes: 11 additions & 11 deletions notebooks/compilation-options.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4.57 s, sys: 950 ms, total: 5.52 s\n",
"Wall time: 5.45 s\n"
"CPU times: user 4.56 s, sys: 1.09 s, total: 5.66 s\n",
"Wall time: 5.56 s\n"
]
}
],
Expand Down Expand Up @@ -273,7 +273,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"1.2 s ± 14.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"1.17 s ± 14.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
Expand Down Expand Up @@ -338,8 +338,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4.48 s, sys: 991 ms, total: 5.47 s\n",
"Wall time: 5.39 s\n"
"CPU times: user 4.65 s, sys: 967 ms, total: 5.62 s\n",
"Wall time: 5.52 s\n"
]
}
],
Expand Down Expand Up @@ -420,7 +420,7 @@
"id": "f074e168-dee4-4234-8eae-afd28fae317f",
"metadata": {},
"source": [
"As we see in the next code snippet, enabling folding provided an (approximately) **20x speed-up** for feed-forward circuit evaluations."
"As we see in the next code snippet, enabling folding provided an (approximately) **19.9x speed-up** for feed-forward circuit evaluations."
]
},
{
Expand All @@ -433,7 +433,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"60 ms ± 15.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
"58.9 ms ± 6.43 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
Expand Down Expand Up @@ -527,8 +527,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4.81 s, sys: 876 ms, total: 5.68 s\n",
"Wall time: 5.61 s\n"
"CPU times: user 5.02 s, sys: 929 ms, total: 5.95 s\n",
"Wall time: 5.85 s\n"
]
}
],
Expand Down Expand Up @@ -591,7 +591,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"26.7 ms ± 15.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
"25.6 ms ± 26.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
Expand All @@ -608,7 +608,7 @@
"id": "11d95c02-2c66-4414-b676-0dec303f2aa9",
"metadata": {},
"source": [
"Note that, we achieved an (approximately) **2.2x speed-up**, when compared to the folded circuit compiled above, and an (approximately) **44.9x speed-up**, when compared to the circuit compiled with no folding and no optimizations."
"Note that, we achieved an (approximately) **2.3x speed-up**, when compared to the folded circuit compiled above, and an (approximately) **45.7x speed-up**, when compared to the circuit compiled with no folding and no optimizations."
]
},
{
Expand Down
Loading