-
Notifications
You must be signed in to change notification settings - Fork 114
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove Tket dependencies for PennyLane model #195
base: main
Are you sure you want to change the base?
Changes from all commits
08462a3
dee07d5
b8e14c7
94588f7
ea3dd27
e08ed62
a14bd84
a47a994
96e8a16
383c37c
9a41779
ecaa0f8
de6e7e7
b60d013
aebf473
2b1d692
9f0b144
3476fb8
df26e0f
4581628
8846e07
d75e524
a401b9d
4fc95d0
347f93d
243fa51
f779122
7e55448
e5bdbac
2bffca5
e60bf89
330537e
3a56070
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,16 +26,20 @@ | |
|
||
import numpy as np | ||
import pytket as tk | ||
from pytket.circuit import (Bit, Command, Op, OpType, Qubit) | ||
from pytket.circuit import (Bit, Command, Op, OpType) | ||
from pytket.utils import probs_from_counts | ||
import sympy | ||
from typing_extensions import Self | ||
|
||
from lambeq.backend import Functor, Symbol | ||
from lambeq.backend.quantum import (bit, Box, Bra, CCX, CCZ, Controlled, CRx, | ||
CRy, CRz, Daggered, Diagram, Discard, | ||
GATES, Id, Ket, Measure, quantum, qubit, | ||
Rx, Ry, Rz, Scalar, Swap, X, Y, Z) | ||
from lambeq.backend import Symbol | ||
from lambeq.backend.quantum import (bit, Box, Bra, CCX, CCZ, | ||
circuital_to_dict, | ||
Controlled, CRx, CRy, CRz, | ||
Diagram, Discard, GATES, Id, | ||
is_circuital, Ket, Measure, | ||
qubit, Rx, Ry, Rz, Scalar, Swap, | ||
to_circuital, X, Y, Z | ||
) | ||
|
||
OPTYPE_MAP = {'H': OpType.H, | ||
'X': OpType.X, | ||
|
@@ -52,7 +56,7 @@ | |
'CRy': OpType.CRy, | ||
'CRz': OpType.CRz, | ||
'CCX': OpType.CCX, | ||
'Swap': OpType.SWAP} | ||
'SWAP': OpType.SWAP} | ||
|
||
|
||
class Circuit(tk.Circuit): | ||
|
@@ -192,161 +196,6 @@ def get_counts(self, | |
return counts | ||
|
||
|
||
def to_tk(circuit: Diagram): | ||
""" | ||
Takes a :py:class:`lambeq.quantum.Diagram`, returns | ||
a :py:class:`Circuit`. | ||
""" | ||
# bits and qubits are lists of register indices, at layer i we want | ||
# len(bits) == circuit[:i].cod.count(bit) and same for qubits | ||
tk_circ = Circuit() | ||
bits: list[int] = [] | ||
qubits: list[int] = [] | ||
circuit = circuit.init_and_discard() | ||
|
||
def remove_ketbra1(_, box: Box) -> Diagram | Box: | ||
ob_map: dict[Box, Diagram] | ||
ob_map = {Ket(1): Ket(0) >> X, # type: ignore[dict-item] | ||
Bra(1): X >> Bra(0)} # type: ignore[dict-item] | ||
return ob_map.get(box, box) | ||
|
||
def prepare_qubits(qubits: list[int], | ||
box: Box, | ||
offset: int) -> list[int]: | ||
renaming = dict() | ||
start = (tk_circ.n_qubits if not qubits else 0 | ||
if not offset else qubits[offset - 1] + 1) | ||
for i in range(start, tk_circ.n_qubits): | ||
old = Qubit('q', i) | ||
new = Qubit('q', i + len(box.cod)) | ||
renaming.update({old: new}) | ||
tk_circ.rename_units(renaming) | ||
tk_circ.add_blank_wires(len(box.cod)) | ||
return (qubits[:offset] + list(range(start, start + len(box.cod))) | ||
+ [i + len(box.cod) for i in qubits[offset:]]) | ||
|
||
def measure_qubits(qubits: list[int], | ||
bits: list[int], | ||
box: Box, | ||
bit_offset: int, | ||
qubit_offset: int) -> tuple[list[int], list[int]]: | ||
if isinstance(box, Bra): | ||
tk_circ.post_select({len(tk_circ.bits): box.bit}) | ||
for j, _ in enumerate(box.dom): | ||
i_bit, i_qubit = len(tk_circ.bits), qubits[qubit_offset + j] | ||
offset = len(bits) if isinstance(box, Measure) else None | ||
tk_circ.add_bit(Bit(i_bit), offset=offset) | ||
tk_circ.Measure(i_qubit, i_bit) | ||
if isinstance(box, Measure): | ||
bits = bits[:bit_offset + j] + [i_bit] + bits[bit_offset + j:] | ||
# remove measured qubits | ||
qubits = (qubits[:qubit_offset] | ||
+ qubits[qubit_offset + len(box.dom):]) | ||
return bits, qubits | ||
|
||
def swap(i: int, j: int, unit_factory=Qubit) -> None: | ||
old, tmp, new = ( | ||
unit_factory(i), unit_factory('tmp', 0), unit_factory(j)) | ||
tk_circ.rename_units({old: tmp}) | ||
tk_circ.rename_units({new: old}) | ||
tk_circ.rename_units({tmp: new}) | ||
|
||
def add_gate(qubits: list[int], box: Box, offset: int) -> None: | ||
|
||
is_dagger = False | ||
if isinstance(box, Daggered): | ||
box = box.dagger() | ||
is_dagger = True | ||
|
||
i_qubits = [qubits[offset + j] for j in range(len(box.dom))] | ||
|
||
if isinstance(box, (Rx, Ry, Rz)): | ||
phase = box.phase | ||
if isinstance(box.phase, Symbol): | ||
# Tket uses sympy, lambeq uses custom symbol | ||
phase = box.phase.to_sympy() | ||
op = Op.create(OPTYPE_MAP[box.name[:2]], 2 * phase) | ||
elif isinstance(box, Controlled): | ||
# The following works only for controls on single qubit gates | ||
|
||
# reverse the distance order | ||
dists = [] | ||
curr_box: Box | Controlled = box | ||
while isinstance(curr_box, Controlled): | ||
dists.append(curr_box.distance) | ||
curr_box = curr_box.controlled | ||
dists.reverse() | ||
|
||
# Index of the controlled qubit is the last entry in rel_idx | ||
rel_idx = [0] | ||
for dist in dists: | ||
if dist > 0: | ||
# Add control to the left, offset by distance | ||
rel_idx = [0] + [i + dist for i in rel_idx] | ||
else: | ||
# Add control to the right, don't offset | ||
right_most_idx = max(rel_idx) | ||
rel_idx.insert(-1, right_most_idx - dist) | ||
|
||
i_qubits = [i_qubits[i] for i in rel_idx] | ||
|
||
name = box.name.split('(')[0] | ||
if box.name in ('CX', 'CZ', 'CCX'): | ||
op = Op.create(OPTYPE_MAP[name]) | ||
elif name in ('CRx', 'CRz'): | ||
phase = box.phase | ||
if isinstance(box.phase, Symbol): | ||
# Tket uses sympy, lambeq uses custom symbol | ||
phase = box.phase.to_sympy() | ||
|
||
op = Op.create(OPTYPE_MAP[name], 2 * phase) | ||
elif name in ('CCX'): | ||
op = Op.create(OPTYPE_MAP[name]) | ||
elif box.name in OPTYPE_MAP: | ||
op = Op.create(OPTYPE_MAP[box.name]) | ||
else: | ||
raise NotImplementedError(box) | ||
|
||
if is_dagger: | ||
op = op.dagger | ||
|
||
tk_circ.add_gate(op, i_qubits) | ||
|
||
circuit = Functor(target_category=quantum, # type: ignore [assignment] | ||
ob=lambda _, x: x, | ||
ar=remove_ketbra1)(circuit) # type: ignore [arg-type] | ||
for left, box, _ in circuit: | ||
if isinstance(box, Ket): | ||
qubits = prepare_qubits(qubits, box, left.count(qubit)) | ||
elif isinstance(box, (Measure, Bra)): | ||
bits, qubits = measure_qubits( | ||
qubits, bits, box, left.count(bit), left.count(qubit)) | ||
elif isinstance(box, Discard): | ||
qubits = (qubits[:left.count(qubit)] | ||
+ qubits[left.count(qubit) + box.dom.count(qubit):]) | ||
elif isinstance(box, Swap): | ||
if box == Swap(qubit, qubit): | ||
off = left.count(qubit) | ||
swap(qubits[off], qubits[off + 1]) | ||
elif box == Swap(bit, bit): | ||
off = left.count(bit) | ||
if tk_circ.post_processing: | ||
right = Id(tk_circ.post_processing.cod[off + 2:]) | ||
tk_circ.post_process( | ||
Id(bit ** off) @ Swap(bit, bit) @ right) | ||
else: | ||
swap(bits[off], bits[off + 1], unit_factory=Bit) | ||
else: # pragma: no cover | ||
continue # bits and qubits live in different registers. | ||
elif isinstance(box, Scalar): | ||
tk_circ.scale(abs(box.array) ** 2) | ||
elif isinstance(box, Box): | ||
add_gate(qubits, box, left.count(qubit)) | ||
else: # pragma: no cover | ||
raise NotImplementedError | ||
return tk_circ | ||
|
||
|
||
def _tk_to_lmbq_param(theta): | ||
if not isinstance(theta, sympy.Expr): | ||
return theta | ||
|
@@ -362,6 +211,71 @@ def _tk_to_lmbq_param(theta): | |
raise ValueError('Parameter must be a (possibly scaled) sympy Symbol') | ||
|
||
|
||
def to_tk(diagram: Diagram): | ||
"""Takes a :py:class:`lambeq.quantum.Diagram`, returns | ||
a :class:`lambeq.backend.converters.tk.Circuit` | ||
for t|ket>. | ||
|
||
|
||
Parameters | ||
---------- | ||
diagram : :py:class:`~lambeq.backend.quantum.Diagram` | ||
The :py:class:`Circuits <lambeq.backend.quantum.Diagram>` | ||
to be converted to a tket circuit. | ||
|
||
Returns | ||
------- | ||
tk_circuit : lambeq.backend.quantum | ||
A :class:`lambeq.backend.converters.tk.Circuit`. | ||
|
||
Note | ||
---- | ||
* Converts to circuital. | ||
* Copies the diagram to avoid modifying the original. | ||
""" | ||
|
||
if not is_circuital(diagram): | ||
diagram = to_circuital(diagram) | ||
|
||
circuit_dict = circuital_to_dict(diagram) | ||
|
||
post_select = {postselect['qubit']: postselect['phase'] | ||
for postselect in | ||
circuit_dict['measurements']['post']} | ||
|
||
circuit = Circuit(circuit_dict['qubits']['total'], | ||
len(circuit_dict['qubits']['bitmap']), | ||
post_selection=post_select | ||
) | ||
|
||
for gate in circuit_dict['gates']: | ||
|
||
if gate['type'] == 'Scalar': | ||
circuit.scale(abs(gate['phase'])**2) | ||
continue | ||
elif not gate['type'] in OPTYPE_MAP: | ||
raise NotImplementedError(f'Gate {gate} not supported') | ||
|
||
if 'phase' in gate and gate['phase']: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
op = Op.create(OPTYPE_MAP[gate['type']], 2 * gate['phase']) | ||
else: | ||
op = Op.create(OPTYPE_MAP[gate['type']]) | ||
|
||
if gate['dagger']: | ||
op = op.dagger | ||
|
||
qubits = gate['qubits'] | ||
circuit.add_gate(op, qubits) | ||
|
||
for measure in circuit_dict['measurements']['measure']: | ||
circuit.Measure(measure['qubit'], measure['bit']) | ||
|
||
for postselect in circuit_dict['measurements']['post']: | ||
circuit.Measure(postselect['qubit'], postselect['bit']) | ||
|
||
return circuit | ||
|
||
|
||
def from_tk(tk_circuit: tk.Circuit) -> Diagram: | ||
"""Translates from tket to a lambeq Diagram.""" | ||
tk_circ: Circuit = Circuit.upgrade(tk_circuit) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -203,6 +203,62 @@ def __getitem__(self, index: int | slice) -> Self: | |
else: | ||
return self._fromiter(objects[index]) | ||
|
||
def replace(self, other: Self, index: int) -> Self: | ||
"""Replace a type at the specified index in the complex type list. | ||
|
||
Parameters | ||
---------- | ||
other : Ty | ||
The type to insert. Can be atomic or complex. | ||
index : int | ||
The position where the type should be inserted. | ||
""" | ||
if not (index <= len(self) and index >= 0): | ||
raise IndexError(f'Index {index} out of bounds for ' | ||
f'type {self} with length {len(self)}.') | ||
|
||
if self.is_empty: | ||
return other | ||
else: | ||
objects = self.objects.copy() | ||
|
||
if len(objects) == 1: | ||
return other | ||
|
||
if index == 0: | ||
objects = [*other] + objects[1:] | ||
elif index == len(self): | ||
objects = objects[:-1] + [*other] | ||
else: | ||
objects = objects[:index] + [*other] + objects[index+1:] | ||
|
||
return self._fromiter(objects) | ||
|
||
def insert(self, other: Self, index: int) -> Self: | ||
"""Insert a type at the specified index in the complex type list. | ||
|
||
Parameters | ||
---------- | ||
other : Ty | ||
The type to insert. Can be atomic or complex. | ||
index : int | ||
The position where the type should be inserted. | ||
""" | ||
if not (index <= len(self)): | ||
raise IndexError(f'Index {index} out of bounds for ' | ||
f'type {self} with length {len(self)}.') | ||
|
||
if self.is_empty: | ||
return other | ||
else: | ||
if index == 0: | ||
return other @ self | ||
elif index == len(self): | ||
return self @ other | ||
objects = self.objects.copy() | ||
objects = objects[:index] + [*other] + objects[index:] | ||
return self._fromiter(objects) | ||
|
||
@classmethod | ||
def _fromiter(cls, objects: Iterable[Self]) -> Self: | ||
"""Create a Ty from an iterable of atomic objects.""" | ||
|
@@ -970,8 +1026,10 @@ def then(self, *diagrams: Diagrammable) -> Self: | |
cod = self.cod | ||
for n, diagram in enumerate(diags): | ||
if diagram.dom != cod: | ||
raise ValueError(f'Diagram {n} (cod={cod}) does not compose ' | ||
f'with diagram {n+1} (dom={diagram.dom})') | ||
raise ValueError(f'Diagram {n} ' | ||
f'(cod={cod.__repr__()}) ' | ||
f'does not compose with diagram {n+1} ' | ||
f'(dom={diagram.dom.__repr__()})') | ||
cod = diagram.cod | ||
|
||
layers.extend(diagram.layers) | ||
|
@@ -1912,14 +1970,14 @@ class Functor: | |
>>> n = Ty('n') | ||
>>> diag = Cap(n, n.l) @ Id(n) >> Id(n) @ Cup(n.l, n) | ||
>>> diag.draw( | ||
... figsize=(2, 2), path='./snake.png') | ||
... figsize=(2, 2), path='./docs/_static/images/snake.png') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Be sure you pull the main again, this file here looks outdated (or some problem happened with conflict resolution). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, be sure to keep the upstream changes for these images above. |
||
|
||
.. image:: ./_static/images/snake.png | ||
:align: center | ||
|
||
>>> F = Functor(grammar, lambda _, ty : ty @ ty) | ||
>>> F(diag).draw( | ||
... figsize=(2, 2), path='./snake-2.png') | ||
... figsize=(2, 2), path='./docs/_static/images/snake-2.png') | ||
|
||
.. image:: ./_static/images/snake-2.png | ||
:align: center | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No return type here?