Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix symbolic call graphs for factoring phase estimates #1497

Merged
merged 4 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions qualtran/bloqs/arithmetic/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,9 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -
def build_composite_bloq(
self, bb: 'BloqBuilder', x: 'Soquet', y: 'Soquet', target: 'Soquet'
) -> Dict[str, 'SoquetT']:
if is_symbolic(self.dtype.bitsize):
raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.")

cvs: Union[list[int], HasLength]
if isinstance(self.bitsize, int):
cvs = [0] * self.bitsize
Expand Down Expand Up @@ -1151,6 +1154,8 @@ def wire_symbol(
def build_composite_bloq(
self, bb: 'BloqBuilder', ctrl: 'Soquet', a: 'Soquet', b: 'Soquet', target: 'Soquet'
) -> Dict[str, 'SoquetT']:
if is_symbolic(self.dtype.bitsize):
raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.")

if isinstance(self.dtype, QInt):
a = bb.add(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)), x=a)
Expand Down Expand Up @@ -1360,6 +1365,9 @@ def build_composite_bloq(
c: Optional['Soquet'] = None,
target: Optional['Soquet'] = None,
) -> Dict[str, 'SoquetT']:
if is_symbolic(self.dtype.bitsize):
raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.")

if self.uncompute:
# Uncompute
assert c is not None
Expand Down
5 changes: 5 additions & 0 deletions qualtran/bloqs/arithmetic/controlled_addition.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
bloq_example,
BloqBuilder,
BloqDocSpec,
DecomposeTypeError,
QBit,
QInt,
QUInt,
Expand All @@ -37,6 +38,7 @@
from qualtran.bloqs.mcmt.and_bloq import And
from qualtran.resource_counting.generalizers import ignore_split_join
from qualtran.simulation.classical_sim import add_ints
from qualtran.symbolics.types import is_symbolic

if TYPE_CHECKING:
import quimb.tensor as qtn
Expand Down Expand Up @@ -134,6 +136,9 @@ def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
def build_composite_bloq(
self, bb: 'BloqBuilder', ctrl: 'Soquet', a: 'Soquet', b: 'Soquet'
) -> Dict[str, 'SoquetT']:
if is_symbolic(self.a_dtype.bitsize, self.b_dtype.bitsize):
raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `bitsize`.")

a_arr = bb.split(a)
ctrl_q = bb.split(ctrl)[0]
ancilla_arr = []
Expand Down
39 changes: 34 additions & 5 deletions qualtran/bloqs/factoring/_factoring_shims.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,29 @@
# limitations under the License.

from functools import cached_property
from typing import Optional, Tuple
from typing import Dict, Optional, Tuple

import numpy as np
import sympy
from attrs import frozen

from qualtran import Bloq, CompositeBloq, DecomposeTypeError, QBit, Register, Side, Signature
from qualtran import (
Bloq,
BloqBuilder,
DecomposeTypeError,
QBit,
QUInt,
Register,
Side,
Signature,
Soquet,
SoquetT,
)
from qualtran.bloqs.basic_gates._shims import Measure
from qualtran.bloqs.qft import QFTTextBook
from qualtran.drawing import RarrowTextBox, Text, WireSymbol
from qualtran.symbolics import SymbolicInt
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
from qualtran.symbolics.types import SymbolicInt


@frozen
Expand All @@ -30,8 +46,21 @@ class MeasureQFT(Bloq):
def signature(self) -> 'Signature':
return Signature([Register('x', QBit(), shape=(self.n,), side=Side.LEFT)])

def decompose_bloq(self) -> 'CompositeBloq':
raise DecomposeTypeError('MeasureQFT is a placeholder, atomic bloq.')
def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet) -> Dict[str, 'SoquetT']:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since this shim has a decomposition now should it become a bloq ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll leave that up to @mpharrigan. We added this decomposition a few PRs back (or meant to), but it got clobbered when merging two separate PRs touching this file (my mistake). I think the goal eventually is to use the phase_estimation bloqs, but for now I just added this so it could decompose entirely into leaf bloqs.

if isinstance(self.n, sympy.Expr):
raise DecomposeTypeError("Cannot decompose symbolic `n`.")

x = bb.join(np.array(x), dtype=QUInt(self.n))
x = bb.add(QFTTextBook(self.n), q=x)
x = bb.split(x)

for i in range(self.n):
bb.add(Measure(), q=x[i])

return {}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {QFTTextBook(self.n): 1, Measure(): self.n}

def wire_symbol(
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
Expand Down
94 changes: 48 additions & 46 deletions qualtran/bloqs/factoring/ecc/ec_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
)
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT
from qualtran.symbolics.types import HasLength, is_symbolic
from qualtran.symbolics.types import HasLength, is_symbolic, SymbolicInt

from .ec_point import ECPoint

Expand Down Expand Up @@ -80,8 +80,8 @@ class _ECAddStepOne(Bloq):
Fig 10.
"""

n: int
mod: int
n: 'SymbolicInt'
mod: 'SymbolicInt'

@cached_property
def signature(self) -> 'Signature':
Expand Down Expand Up @@ -214,9 +214,9 @@ class _ECAddStepTwo(Bloq):
Fig 10.
"""

n: int
mod: int
window_size: int = 1
n: 'SymbolicInt'
mod: 'SymbolicInt'
window_size: 'SymbolicInt' = 1

@cached_property
def signature(self) -> 'Signature':
Expand Down Expand Up @@ -251,7 +251,9 @@ def on_classical_vals(
f1 = 0
else:
lam = QMontgomeryUInt(self.n).montgomery_product(
int(y), QMontgomeryUInt(self.n).montgomery_inverse(int(x), self.mod), self.mod
int(y),
QMontgomeryUInt(self.n).montgomery_inverse(int(x), int(self.mod)),
int(self.mod),
)
# TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bug in circuit
# which flips f1 when lam and lam_r are equal.
Expand Down Expand Up @@ -299,7 +301,7 @@ def build_composite_bloq(
# If ctrl = 1 and x != a: lam = (y - b) / (x - a) % p.
z4_split = bb.split(z4)
lam_split = bb.split(lam)
for i in range(self.n):
for i in range(int(self.n)):
ctrls = [f1, ctrl, z4_split[i]]
ctrls, lam_split[i] = bb.add(
MultiControlX(cvs=[0, 1, 1]), controls=ctrls, target=lam_split[i]
Expand All @@ -311,7 +313,7 @@ def build_composite_bloq(

# If ctrl = 1 and x = a: lam = lam_r.
lam_r_split = bb.split(lam_r)
for i in range(self.n):
for i in range(int(self.n)):
ctrls = [f1, ctrl, lam_r_split[i]]
ctrls, lam_split[i] = bb.add(
MultiControlX(cvs=[1, 1, 1]), controls=ctrls, target=lam_split[i]
Expand Down Expand Up @@ -383,9 +385,9 @@ class _ECAddStepThree(Bloq):
Fig 10.
"""

n: int
mod: int
window_size: int = 1
n: 'SymbolicInt'
mod: 'SymbolicInt'
window_size: 'SymbolicInt' = 1

@cached_property
def signature(self) -> 'Signature':
Expand Down Expand Up @@ -455,7 +457,7 @@ def build_composite_bloq(
z1 = bb.add(IntState(bitsize=self.n, val=0))
a_split = bb.split(a)
z1_split = bb.split(z1)
for i in range(self.n):
for i in range(int(self.n)):
a_split[i], z1_split[i] = bb.add(CNOT(), ctrl=a_split[i], target=z1_split[i])
a = bb.join(a_split, QMontgomeryUInt(self.n))
z1 = bb.join(z1_split, QMontgomeryUInt(self.n))
Expand All @@ -472,7 +474,7 @@ def build_composite_bloq(
z1 = bb.add(ModDbl(QMontgomeryUInt(self.n), mod=self.mod).adjoint(), x=z1)
a_split = bb.split(a)
z1_split = bb.split(z1)
for i in range(self.n):
for i in range(int(self.n)):
a_split[i], z1_split[i] = bb.add(CNOT(), ctrl=a_split[i], target=z1_split[i])
a = bb.join(a_split, QMontgomeryUInt(self.n))
z1 = bb.join(z1_split, QMontgomeryUInt(self.n))
Expand Down Expand Up @@ -520,9 +522,9 @@ class _ECAddStepFour(Bloq):
Fig 10.
"""

n: int
mod: int
window_size: int = 1
n: 'SymbolicInt'
mod: 'SymbolicInt'
window_size: 'SymbolicInt' = 1

@cached_property
def signature(self) -> 'Signature':
Expand All @@ -538,10 +540,10 @@ def on_classical_vals(
self, x: 'ClassicalValT', y: 'ClassicalValT', lam: 'ClassicalValT'
) -> Dict[str, 'ClassicalValT']:
x = (
x - QMontgomeryUInt(self.n).montgomery_product(int(lam), int(lam), self.mod)
x - QMontgomeryUInt(self.n).montgomery_product(int(lam), int(lam), int(self.mod))
) % self.mod
if lam > 0:
y = QMontgomeryUInt(self.n).montgomery_product(int(x), int(lam), self.mod)
y = QMontgomeryUInt(self.n).montgomery_product(int(x), int(lam), int(self.mod))
return {'x': x, 'y': y, 'lam': lam}

def build_composite_bloq(
Expand All @@ -554,7 +556,7 @@ def build_composite_bloq(
z4 = bb.add(IntState(bitsize=self.n, val=0))
lam_split = bb.split(lam)
z4_split = bb.split(z4)
for i in range(self.n):
for i in range(int(self.n)):
lam_split[i], z4_split[i] = bb.add(CNOT(), ctrl=lam_split[i], target=z4_split[i])
lam = bb.join(lam_split, QMontgomeryUInt(self.n))
z4 = bb.join(z4_split, QMontgomeryUInt(self.n))
Expand Down Expand Up @@ -584,7 +586,7 @@ def build_composite_bloq(
)
lam_split = bb.split(lam)
z4_split = bb.split(z4)
for i in range(self.n):
for i in range(int(self.n)):
lam_split[i], z4_split[i] = bb.add(CNOT(), ctrl=lam_split[i], target=z4_split[i])
lam = bb.join(lam_split, QMontgomeryUInt(self.n))
z4 = bb.join(z4_split, QMontgomeryUInt(self.n))
Expand All @@ -602,7 +604,7 @@ def build_composite_bloq(
# y = y_r + b % p.
z3_split = bb.split(z3)
y_split = bb.split(y)
for i in range(self.n):
for i in range(int(self.n)):
z3_split[i], y_split[i] = bb.add(CNOT(), ctrl=z3_split[i], target=y_split[i])
z3 = bb.join(z3_split, QMontgomeryUInt(self.n))
y = bb.join(y_split, QMontgomeryUInt(self.n))
Expand Down Expand Up @@ -659,9 +661,9 @@ class _ECAddStepFive(Bloq):
Fig 10.
"""

n: int
mod: int
window_size: int = 1
n: 'SymbolicInt'
mod: 'SymbolicInt'
window_size: 'SymbolicInt' = 1

@cached_property
def signature(self) -> 'Signature':
Expand Down Expand Up @@ -720,7 +722,7 @@ def build_composite_bloq(
# If ctrl: lam = 0.
z4_split = bb.split(z4)
lam_split = bb.split(lam)
for i in range(self.n):
for i in range(int(self.n)):
ctrls = [ctrl, z4_split[i]]
ctrls, lam_split[i] = bb.add(
MultiControlX(cvs=[1, 1]), controls=ctrls, target=lam_split[i]
Expand Down Expand Up @@ -801,8 +803,8 @@ class _ECAddStepSix(Bloq):
Fig 10.
"""

n: int
mod: int
n: 'SymbolicInt'
mod: 'SymbolicInt'

@cached_property
def signature(self) -> 'Signature':
Expand Down Expand Up @@ -866,7 +868,7 @@ def build_composite_bloq(
# Set (x, y) to (a, b) if f4 is set.
a_split = bb.split(a)
x_split = bb.split(x)
for i in range(self.n):
for i in range(int(self.n)):
toff_ctrl = [f4, a_split[i]]
toff_ctrl, x_split[i] = bb.add(Toffoli(), ctrl=toff_ctrl, target=x_split[i])
f4 = toff_ctrl[0]
Expand All @@ -875,7 +877,7 @@ def build_composite_bloq(
x = bb.join(x_split, QMontgomeryUInt(self.n))
b_split = bb.split(b)
y_split = bb.split(y)
for i in range(self.n):
for i in range(int(self.n)):
toff_ctrl = [f4, b_split[i]]
toff_ctrl, y_split[i] = bb.add(Toffoli(), ctrl=toff_ctrl, target=y_split[i])
f4 = toff_ctrl[0]
Expand All @@ -888,11 +890,11 @@ def build_composite_bloq(
xy = bb.join(np.concatenate([bb.split(x), bb.split(y)]), dtype=QMontgomeryUInt(2 * self.n))
ab, xy, f4 = bb.add(Equals(QMontgomeryUInt(2 * self.n)), x=ab, y=xy, target=f4)
ab_split = bb.split(ab)
a = bb.join(ab_split[: self.n], dtype=QMontgomeryUInt(self.n))
b = bb.join(ab_split[self.n :], dtype=QMontgomeryUInt(self.n))
a = bb.join(ab_split[: int(self.n)], dtype=QMontgomeryUInt(self.n))
b = bb.join(ab_split[int(self.n) :], dtype=QMontgomeryUInt(self.n))
xy_split = bb.split(xy)
x = bb.join(xy_split[: self.n], dtype=QMontgomeryUInt(self.n))
y = bb.join(xy_split[self.n :], dtype=QMontgomeryUInt(self.n))
x = bb.join(xy_split[: int(self.n)], dtype=QMontgomeryUInt(self.n))
y = bb.join(xy_split[int(self.n) :], dtype=QMontgomeryUInt(self.n))

# Unset f3 if (a, b) = (0, 0).
ab_arr = np.concatenate([bb.split(a), bb.split(b)])
Expand Down Expand Up @@ -1000,9 +1002,9 @@ class ECAdd(Bloq):
Litinski. 2023. Fig 5.
"""

n: int
mod: int
window_size: int = 1
n: 'SymbolicInt'
mod: 'SymbolicInt'
window_size: 'SymbolicInt' = 1

@cached_property
def signature(self) -> 'Signature':
Expand Down Expand Up @@ -1070,29 +1072,29 @@ def build_composite_bloq(

def on_classical_vals(self, a, b, x, y, lam_r) -> Dict[str, Union['ClassicalValT', sympy.Expr]]:
curve_a = (
QMontgomeryUInt(self.n).montgomery_to_uint(lam_r, self.mod)
QMontgomeryUInt(self.n).montgomery_to_uint(lam_r, int(self.mod))
* 2
* QMontgomeryUInt(self.n).montgomery_to_uint(b, self.mod)
- (3 * QMontgomeryUInt(self.n).montgomery_to_uint(a, self.mod) ** 2)
* QMontgomeryUInt(self.n).montgomery_to_uint(b, int(self.mod))
- (3 * QMontgomeryUInt(self.n).montgomery_to_uint(a, int(self.mod)) ** 2)
) % self.mod
p1 = ECPoint(
QMontgomeryUInt(self.n).montgomery_to_uint(a, self.mod),
QMontgomeryUInt(self.n).montgomery_to_uint(b, self.mod),
QMontgomeryUInt(self.n).montgomery_to_uint(a, int(self.mod)),
QMontgomeryUInt(self.n).montgomery_to_uint(b, int(self.mod)),
mod=self.mod,
curve_a=curve_a,
)
p2 = ECPoint(
QMontgomeryUInt(self.n).montgomery_to_uint(x, self.mod),
QMontgomeryUInt(self.n).montgomery_to_uint(y, self.mod),
QMontgomeryUInt(self.n).montgomery_to_uint(x, int(self.mod)),
QMontgomeryUInt(self.n).montgomery_to_uint(y, int(self.mod)),
mod=self.mod,
curve_a=curve_a,
)
result = p1 + p2
return {
'a': a,
'b': b,
'x': QMontgomeryUInt(self.n).uint_to_montgomery(result.x, self.mod),
'y': QMontgomeryUInt(self.n).uint_to_montgomery(result.y, self.mod),
'x': QMontgomeryUInt(self.n).uint_to_montgomery(result.x, int(self.mod)),
'y': QMontgomeryUInt(self.n).uint_to_montgomery(result.y, int(self.mod)),
'lam_r': lam_r,
}

Expand Down
Loading
Loading