Skip to content

Commit

Permalink
Add mixed-scalar support to PennyLaneCircuit (#193)
Browse files Browse the repository at this point in the history
  • Loading branch information
nikhilkhatri authored Nov 29, 2024
1 parent de3f538 commit e219768
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 8 deletions.
7 changes: 4 additions & 3 deletions lambeq/backend/converters/tk.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,10 @@ def to_tk(circuit: Diagram):
qubits: list[int] = []
circuit = circuit.init_and_discard()

def remove_ket1(_, box: Box) -> Diagram | Box:
def remove_ketbra1(_, box: Box) -> Diagram | Box:
ob_map: dict[Box, Diagram]
ob_map = {Ket(1): Ket(0) >> X} # type: ignore[dict-item]
ob_map = {Ket(1): Ket(0) >> X, # type: ignore[dict-item]
Bra(1): X >> Bra(0)} # type: ignore[dict-item]
return ob_map.get(box, box)

def prepare_qubits(qubits: list[int],
Expand Down Expand Up @@ -313,7 +314,7 @@ def add_gate(qubits: list[int], box: Box, offset: int) -> None:

circuit = Functor(target_category=quantum, # type: ignore [assignment]
ob=lambda _, x: x,
ar=remove_ket1)(circuit) # type: ignore [arg-type]
ar=remove_ketbra1)(circuit) # type: ignore [arg-type]
for left, box, _ in circuit:
if isinstance(box, Ket):
qubits = prepare_qubits(qubits, box, left.count(qubit))
Expand Down
26 changes: 21 additions & 5 deletions lambeq/backend/pennylane.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,15 @@
from __future__ import annotations

from itertools import product
import sys
from typing import TYPE_CHECKING

import pennylane as qml
from pytket import OpType
import sympy
import torch

from lambeq.backend.quantum import Scalar
from lambeq.backend.quantum import Measure, Scalar

if TYPE_CHECKING:
from lambeq.backend.quantum import Diagram
Expand Down Expand Up @@ -216,9 +217,16 @@ def to_pennylane(lambeq_circuit: Diagram, probabilities=False,
The PennyLane circuit equivalent to the input lambeq circuit.
"""
if lambeq_circuit.is_mixed:
raise ValueError('Only pure quantum circuits are currently '
'supported.')

if any(isinstance(box, Measure) for box in lambeq_circuit.boxes):
raise ValueError('Only pure circuits, or circuits with discards'
' are currently supported.')

if lambeq_circuit.is_mixed and lambeq_circuit.cod:
# Some qubits discarded, some left open
print('Warning: Circuit includes both discards and open codomain'
' wires. All open wires will be discarded during conversion',
file=sys.stderr)

tk_circ = lambeq_circuit.to_tk()
op_list, params_list, wires_list, symbols_set = (
Expand All @@ -238,6 +246,7 @@ def to_pennylane(lambeq_circuit: Diagram, probabilities=False,
wires_list,
probabilities,
post_selection,
lambeq_circuit.is_mixed,
scalar,
tk_circ.n_qubits,
backend_config,
Expand All @@ -252,14 +261,15 @@ class PennyLaneCircuit:
"""Implement a pennylane circuit with post-selection."""

def __init__(self, ops, symbols, params, wires, probabilities,
post_selection, scale, n_qubits, backend_config,
post_selection, mixed, scale, n_qubits, backend_config,
diff_method):
self._ops = ops
self._symbols = symbols
self._params = params
self._wires = wires
self._probabilities = probabilities
self._post_selection = post_selection
self._mixed = mixed
self._scale = scale
self._n_qubits = n_qubits
self._backend_config = backend_config
Expand Down Expand Up @@ -400,6 +410,8 @@ def circuit(circ_params):
for op, params, wires in zip(self._ops, circ_params, self._wires):
op(*[2 * torch.pi * p for p in params], wires=wires)

if self._mixed:
return qml.density_matrix(self._post_selection.keys())
if self._probabilities:
return qml.probs(wires=range(self._n_qubits))
else:
Expand All @@ -424,6 +436,10 @@ def post_selected_circuit(self, params):
"""
states = self._circuit(params)

if self._mixed:
# Select the all-zeros subsystem
return states[0][0]

open_wires = self._n_qubits - len(self._post_selection)
post_selected_states = states[self._valid_states]
post_selected_states *= (self._scale ** 2 if self._probabilities
Expand Down
10 changes: 10 additions & 0 deletions tests/backend/test_pennylane_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,16 @@ def test_pennylane_circuit_mixed_error():
snake.to_pennylane()


def test_pennylane_circuit_mixed_warning(capsys):
bell_state = Diagram.caps(qubit, qubit)
bell_discarded = bell_state >> Discard() @ Id(qubit)
_ = bell_discarded.to_pennylane()
captured = capsys.readouterr()
assert captured.err == ('Warning: Circuit includes both discards and open '
'codomain wires. All open wires will be discarded '
'during conversion\n')


def test_pennylane_circuit_draw(capsys):
bell_state = Diagram.caps(qubit, qubit)
bell_effect = bell_state[::-1]
Expand Down

0 comments on commit e219768

Please sign in to comment.