From 1252b786ca4e3b2edab0a6340753b28e98801b34 Mon Sep 17 00:00:00 2001 From: loreloc Date: Mon, 11 Nov 2024 11:37:23 +0000 Subject: [PATCH 1/2] removed previous categorical one hot impl and simplified the index impl --- cirkit/backend/torch/layers/input.py | 34 ++++++++-------------------- 1 file changed, 9 insertions(+), 25 deletions(-) diff --git a/cirkit/backend/torch/layers/input.py b/cirkit/backend/torch/layers/input.py index 9af0b2c4..8b9bdc99 100644 --- a/cirkit/backend/torch/layers/input.py +++ b/cirkit/backend/torch/layers/input.py @@ -296,9 +296,6 @@ def __init__( raise ValueError(f"The number of folds and shape of 'probs' must match the layer's") self.probs = probs self.logits = logits - self.idx_mode = ( - len(torch.unique(self.scope_idx)) > 4096 or self.num_categories > 256 - ) def _valid_parameter_shape(self, p: TorchParameter) -> bool: if p.num_folds != self.num_folds: @@ -326,30 +323,17 @@ def params(self) -> Mapping[str, TorchParameter]: def log_unnormalized_likelihood(self, x: Tensor) -> Tensor: if x.is_floating_point(): x = x.long() # The input to Categorical should be discrete + # x: (F, C, B, 1) -> (F, C, B) + x = x.squeeze(dim=3) + # logits: (F, K, C, N) logits = torch.log(self.probs()) if self.logits is None else self.logits() - if self.idx_mode: - if self.num_channels == 1: - x = ( - logits[:, :, 0, :] - .transpose(1, 2)[range(self.num_folds), x[:, 0, :, 0].t()] - .transpose(0, 1) - ) - else: - x = x[..., 0].permute(2, 0, 1) - x = ( - logits[ - torch.arange(self.num_folds).unsqueeze(1), - :, - torch.arange(self.num_channels).unsqueeze(0), - x, - ] - .sum(2) - .transpose(0, 1) - ) + if self.num_channels == 1: + idx_fold = torch.arange(self.num_folds, device=logits.device) + x = logits[:, :, 0][idx_fold[:, None], :, x[:, 0]] else: - x = F.one_hot(x, self.num_categories) # (F, C, B, 1, num_categories) - x = x.squeeze(dim=3) # (F, C, B, num_categories) - x = torch.einsum("fcbi,fkci->fbk", x.to(logits.dtype), logits) + idx_fold = torch.arange(self.num_folds, device=logits.device)[:, None, None] + idx_channel = torch.arange(self.num_channels)[None, :, None] + x = torch.sum(logits[idx_fold, :, idx_channel, x], dim=1) return x def log_partition_function(self) -> Tensor: From 4bb46177476ee3619beb563949d5a20ce6fa9340 Mon Sep 17 00:00:00 2001 From: loreloc Date: Mon, 11 Nov 2024 11:43:54 +0000 Subject: [PATCH 2/2] re-run compilation-options.ipynb --- notebooks/compilation-options.ipynb | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/notebooks/compilation-options.ipynb b/notebooks/compilation-options.ipynb index c9cf0adf..c469db4c 100644 --- a/notebooks/compilation-options.ipynb +++ b/notebooks/compilation-options.ipynb @@ -203,8 +203,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 4.57 s, sys: 950 ms, total: 5.52 s\n", - "Wall time: 5.45 s\n" + "CPU times: user 4.56 s, sys: 1.09 s, total: 5.66 s\n", + "Wall time: 5.56 s\n" ] } ], @@ -273,7 +273,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "1.2 s ± 14.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + "1.17 s ± 14.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], @@ -338,8 +338,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 4.48 s, sys: 991 ms, total: 5.47 s\n", - "Wall time: 5.39 s\n" + "CPU times: user 4.65 s, sys: 967 ms, total: 5.62 s\n", + "Wall time: 5.52 s\n" ] } ], @@ -420,7 +420,7 @@ "id": "f074e168-dee4-4234-8eae-afd28fae317f", "metadata": {}, "source": [ - "As we see in the next code snippet, enabling folding provided an (approximately) **20x speed-up** for feed-forward circuit evaluations." + "As we see in the next code snippet, enabling folding provided an (approximately) **19.9x speed-up** for feed-forward circuit evaluations." ] }, { @@ -433,7 +433,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "60 ms ± 15.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + "58.9 ms ± 6.43 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" ] } ], @@ -527,8 +527,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 4.81 s, sys: 876 ms, total: 5.68 s\n", - "Wall time: 5.61 s\n" + "CPU times: user 5.02 s, sys: 929 ms, total: 5.95 s\n", + "Wall time: 5.85 s\n" ] } ], @@ -591,7 +591,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "26.7 ms ± 15.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + "25.6 ms ± 26.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" ] } ], @@ -608,7 +608,7 @@ "id": "11d95c02-2c66-4414-b676-0dec303f2aa9", "metadata": {}, "source": [ - "Note that, we achieved an (approximately) **2.2x speed-up**, when compared to the folded circuit compiled above, and an (approximately) **44.9x speed-up**, when compared to the circuit compiled with no folding and no optimizations." + "Note that, we achieved an (approximately) **2.3x speed-up**, when compared to the folded circuit compiled above, and an (approximately) **45.7x speed-up**, when compared to the circuit compiled with no folding and no optimizations." ] }, {