Skip to content

Commit

Permalink
Merge branch 'main' into patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
loreloc authored Nov 11, 2024
2 parents 9dd6335 + 3ce47e0 commit 9dc089e
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 37 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
run: |
./scripts/coverage.sh --xml
- name: Upload coverage reports to Codecov
if: ${{ github.event.push && (github.event.push.ref == github.event.repository.default_branch) }}$
if: ${{ github.repository == 'april-tools/cirkit' && github.event.push && (github.event.push.ref == github.event.repository.default_branch) }}$
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
Expand Down
34 changes: 9 additions & 25 deletions cirkit/backend/torch/layers/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,6 @@ def __init__(
raise ValueError(f"The number of folds and shape of 'probs' must match the layer's")
self.probs = probs
self.logits = logits
self.idx_mode = (
len(torch.unique(self.scope_idx)) > 4096 or self.num_categories > 256
)

def _valid_parameter_shape(self, p: TorchParameter) -> bool:
if p.num_folds != self.num_folds:
Expand Down Expand Up @@ -326,30 +323,17 @@ def params(self) -> Mapping[str, TorchParameter]:
def log_unnormalized_likelihood(self, x: Tensor) -> Tensor:
if x.is_floating_point():
x = x.long() # The input to Categorical should be discrete
# x: (F, C, B, 1) -> (F, C, B)
x = x.squeeze(dim=3)
# logits: (F, K, C, N)
logits = torch.log(self.probs()) if self.logits is None else self.logits()
if self.idx_mode:
if self.num_channels == 1:
x = (
logits[:, :, 0, :]
.transpose(1, 2)[range(self.num_folds), x[:, 0, :, 0].t()]
.transpose(0, 1)
)
else:
x = x[..., 0].permute(2, 0, 1)
x = (
logits[
torch.arange(self.num_folds).unsqueeze(1),
:,
torch.arange(self.num_channels).unsqueeze(0),
x,
]
.sum(2)
.transpose(0, 1)
)
if self.num_channels == 1:
idx_fold = torch.arange(self.num_folds, device=logits.device)
x = logits[:, :, 0][idx_fold[:, None], :, x[:, 0]]
else:
x = F.one_hot(x, self.num_categories) # (F, C, B, 1, num_categories)
x = x.squeeze(dim=3) # (F, C, B, num_categories)
x = torch.einsum("fcbi,fkci->fbk", x.to(logits.dtype), logits)
idx_fold = torch.arange(self.num_folds, device=logits.device)[:, None, None]
idx_channel = torch.arange(self.num_channels)[None, :, None]
x = torch.sum(logits[idx_fold, :, idx_channel, x], dim=1)
return x

def log_partition_function(self) -> Tensor:
Expand Down
22 changes: 11 additions & 11 deletions notebooks/compilation-options.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4.57 s, sys: 950 ms, total: 5.52 s\n",
"Wall time: 5.45 s\n"
"CPU times: user 4.56 s, sys: 1.09 s, total: 5.66 s\n",
"Wall time: 5.56 s\n"
]
}
],
Expand Down Expand Up @@ -273,7 +273,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"1.2 s ± 14.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
"1.17 s ± 14.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
Expand Down Expand Up @@ -338,8 +338,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4.48 s, sys: 991 ms, total: 5.47 s\n",
"Wall time: 5.39 s\n"
"CPU times: user 4.65 s, sys: 967 ms, total: 5.62 s\n",
"Wall time: 5.52 s\n"
]
}
],
Expand Down Expand Up @@ -420,7 +420,7 @@
"id": "f074e168-dee4-4234-8eae-afd28fae317f",
"metadata": {},
"source": [
"As we see in the next code snippet, enabling folding provided an (approximately) **20x speed-up** for feed-forward circuit evaluations."
"As we see in the next code snippet, enabling folding provided an (approximately) **19.9x speed-up** for feed-forward circuit evaluations."
]
},
{
Expand All @@ -433,7 +433,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"60 ms ± 15.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
"58.9 ms ± 6.43 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
Expand Down Expand Up @@ -527,8 +527,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4.81 s, sys: 876 ms, total: 5.68 s\n",
"Wall time: 5.61 s\n"
"CPU times: user 5.02 s, sys: 929 ms, total: 5.95 s\n",
"Wall time: 5.85 s\n"
]
}
],
Expand Down Expand Up @@ -591,7 +591,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"26.7 ms ± 15.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
"25.6 ms ± 26.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
Expand All @@ -608,7 +608,7 @@
"id": "11d95c02-2c66-4414-b676-0dec303f2aa9",
"metadata": {},
"source": [
"Note that, we achieved an (approximately) **2.2x speed-up**, when compared to the folded circuit compiled above, and an (approximately) **44.9x speed-up**, when compared to the circuit compiled with no folding and no optimizations."
"Note that, we achieved an (approximately) **2.3x speed-up**, when compared to the folded circuit compiled above, and an (approximately) **45.7x speed-up**, when compared to the circuit compiled with no folding and no optimizations."
]
},
{
Expand Down

0 comments on commit 9dc089e

Please sign in to comment.