Skip to content

Commit

Permalink
Add new symbol class to Lambeq (#191)
Browse files Browse the repository at this point in the history
Replace sympy symbols in the backend with a lambeq-native class.
  • Loading branch information
blakewilsonquantinuum authored Nov 26, 2024
1 parent 9ce7774 commit de3f538
Show file tree
Hide file tree
Showing 25 changed files with 358 additions and 178 deletions.
4 changes: 3 additions & 1 deletion lambeq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
'SpiderAnsatz',
'StronglyEntanglingAnsatz',
'Symbol',
'lambdify',
'TensorAnsatz',

'CCGType',
Expand Down Expand Up @@ -104,10 +105,11 @@
'MSELoss',
]

from lambeq.backend import Symbol, lambdify
from lambeq import ansatz, core, rewrite, text2diagram, tokeniser, training
from lambeq.ansatz import (BaseAnsatz, CircuitAnsatz, IQPAnsatz, MPSAnsatz,
Sim14Ansatz, Sim15Ansatz, Sim4Ansatz, SpiderAnsatz,
StronglyEntanglingAnsatz, Symbol, TensorAnsatz)
StronglyEntanglingAnsatz, TensorAnsatz)
from lambeq.core.globals import VerbosityLevel
from lambeq.core.types import AtomicType
from lambeq.rewrite import (CoordinationRewriteRule, CurryRewriteRule,
Expand Down
4 changes: 2 additions & 2 deletions lambeq/ansatz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

__all__ = ['BaseAnsatz', 'CircuitAnsatz', 'IQPAnsatz', 'MPSAnsatz',
'Sim14Ansatz', 'Sim15Ansatz', 'Sim4Ansatz', 'SpiderAnsatz',
'StronglyEntanglingAnsatz', 'Symbol', 'TensorAnsatz']
'StronglyEntanglingAnsatz', 'TensorAnsatz']

from lambeq.ansatz.base import BaseAnsatz, Symbol
from lambeq.ansatz.base import BaseAnsatz
from lambeq.ansatz.circuit import (CircuitAnsatz, IQPAnsatz,
Sim14Ansatz, Sim15Ansatz, Sim4Ansatz,
StronglyEntanglingAnsatz)
Expand Down
65 changes: 1 addition & 64 deletions lambeq/ansatz/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,72 +24,9 @@

from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Any, Literal

import sympy

from lambeq.backend import grammar, tensor


class Symbol(sympy.Symbol):
"""A sympy symbol augmented with extra information.
Attributes
----------
directed_dom : int
The size of the domain of the tensor-box that this symbol
represents.
directed_cod : int
The size of the codomain of the tensor-box that this symbol
represents.
size : int
The total size of the tensor that this symbol represents
(directed_dom * directed_cod).
"""
directed_dom: int
directed_cod: int

def __new__(cls,
name: str,
directed_dom: int = 1,
directed_cod: int = 1,
**assumptions: bool) -> Symbol:
"""Initialise a symbol.
Parameters
----------
directed_dom : int, default: 1
The size of the domain of the tensor-box that this symbol
represents.
directed_cod : int, default: 1
The size of the codomain of the tensor-box that this symbol
represents.
"""
cls._sanitize(assumptions, cls)

obj: Symbol = sympy.Symbol.__xnew__(cls, name, **assumptions)
obj.directed_dom = directed_dom
obj.directed_cod = directed_cod
return obj

def __getnewargs_ex__(self) -> tuple[tuple[str, int], dict[str, bool]]:
return (self.name, self.size), self.assumptions0

@property
def size(self) -> int:
return self.directed_dom * self.directed_cod

@sympy.cacheit
def sort_key(self, order: Literal[None] = None) -> tuple[Any, ...]:
return (self.class_key(),
(2, (self.name, self.size)),
sympy.S.One.sort_key(),
sympy.S.One)

def _hashable_content(self) -> tuple[Any, ...]:
return (*super()._hashable_content(), self.size)
from lambeq.backend.symbol import Symbol


class BaseAnsatz(ABC):
Expand Down
9 changes: 5 additions & 4 deletions lambeq/ansatz/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from typing import Type

import numpy as np
from sympy import Symbol, symbols

from lambeq.ansatz import BaseAnsatz
from lambeq.backend.grammar import Box, Diagram, Functor, Ty
Expand All @@ -52,6 +51,7 @@
Rotation,
Rx, Ry, Rz
)
from lambeq.backend.symbol import Symbol

computational_basis = Id(qubit)

Expand Down Expand Up @@ -132,14 +132,15 @@ def _ar(self, _: Functor, box: Box) -> Circuit:
if n_qubits == 0:
circuit = Id()
elif n_qubits == 1:
syms = symbols(f'{label}_0:{self.n_single_qubit_params}',
cls=Symbol)
syms = [Symbol(f'{label}_{i}')
for i in range(self.n_single_qubit_params)]
circuit = Id(qubit)
for rot, sym in zip(cycle(self.single_qubit_rotations), syms):
circuit >>= rot(sym)
else:
params_shape = self.params_shape(n_qubits)
syms = symbols(f'{label}_0:{np.prod(params_shape)}', cls=Symbol)
syms = [Symbol(f'{label}_{i}')
for i in range(np.prod(params_shape))]
params: np.ndarray = np.array(syms).reshape(params_shape)
circuit = self.circuit(n_qubits, params)

Expand Down
4 changes: 2 additions & 2 deletions lambeq/ansatz/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@

from collections.abc import Mapping

from lambeq.ansatz import BaseAnsatz, Symbol
from lambeq.backend import grammar, tensor
from lambeq.ansatz import BaseAnsatz
from lambeq.backend import grammar, Symbol, tensor
from lambeq.backend.grammar import Cup, Spider, Ty, Word
from lambeq.backend.tensor import Dim

Expand Down
5 changes: 4 additions & 1 deletion lambeq/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@

'draw',
'draw_equation',
'to_gif']
'to_gif',
'Symbol',
'lambdify']

from lambeq.backend.grammar import (Box, Cap, Category, Cup, Diagram,
Frame, Functor, Id, Spider, Swap, Ty, Word)
from lambeq.backend.pregroup_tree import PregroupTreeNode
from lambeq.backend.symbol import lambdify, Symbol
from lambeq.backend.drawing import draw, draw_equation, to_gif
48 changes: 38 additions & 10 deletions lambeq/backend/converters/tk.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@
import pytket as tk
from pytket.circuit import (Bit, Command, Op, OpType, Qubit)
from pytket.utils import probs_from_counts
import sympy
from typing_extensions import Self

from lambeq.backend import Functor
from lambeq.backend import Functor, Symbol
from lambeq.backend.quantum import (bit, Box, Bra, CCX, CCZ, Controlled, CRx,
CRy, CRz, Daggered, Diagram, Discard,
GATES, Id, Ket, Measure, quantum, qubit,
Expand Down Expand Up @@ -259,7 +260,11 @@ def add_gate(qubits: list[int], box: Box, offset: int) -> None:
i_qubits = [qubits[offset + j] for j in range(len(box.dom))]

if isinstance(box, (Rx, Ry, Rz)):
op = Op.create(OPTYPE_MAP[box.name[:2]], 2 * box.phase)
phase = box.phase
if isinstance(box.phase, Symbol):
# Tket uses sympy, lambeq uses custom symbol
phase = box.phase.to_sympy()
op = Op.create(OPTYPE_MAP[box.name[:2]], 2 * phase)
elif isinstance(box, Controlled):
# The following works only for controls on single qubit gates

Expand Down Expand Up @@ -287,8 +292,13 @@ def add_gate(qubits: list[int], box: Box, offset: int) -> None:
name = box.name.split('(')[0]
if box.name in ('CX', 'CZ', 'CCX'):
op = Op.create(OPTYPE_MAP[name])
elif name in ('CRx', 'CRz'): # TODO Controlled rotations
op = Op.create(OPTYPE_MAP[name], 2 * box.phase)
elif name in ('CRx', 'CRz'):
phase = box.phase
if isinstance(box.phase, Symbol):
# Tket uses sympy, lambeq uses custom symbol
phase = box.phase.to_sympy()

op = Op.create(OPTYPE_MAP[name], 2 * phase)
elif name in ('CCX'):
op = Op.create(OPTYPE_MAP[name])
elif box.name in OPTYPE_MAP:
Expand Down Expand Up @@ -336,6 +346,21 @@ def add_gate(qubits: list[int], box: Box, offset: int) -> None:
return tk_circ


def _tk_to_lmbq_param(theta):
if not isinstance(theta, sympy.Expr):
return theta
elif isinstance(theta, sympy.Symbol):
return Symbol(theta.name)
elif isinstance(theta, sympy.Mul):
scale, symbol = theta.as_coeff_Mul()
if not isinstance(symbol, sympy.Symbol):
raise ValueError('Parameter must be a (possibly scaled) sympy'
'Symbol')
return Symbol(symbol.name, scale=scale)
else:
raise ValueError('Parameter must be a (possibly scaled) sympy Symbol')


def from_tk(tk_circuit: tk.Circuit) -> Diagram:
"""Translates from tket to a lambeq Diagram."""
tk_circ: Circuit = Circuit.upgrade(tk_circuit)
Expand All @@ -354,11 +379,11 @@ def box_and_offset_from_tk(tk_gate) -> tuple[Diagram, int]:

if len(tk_gate.args) == 1: # single qubit gate
if name == 'Rx':
box = Rx(tk_gate.op.params[0] / 2)
box = Rx(_tk_to_lmbq_param(tk_gate.op.params[0]) * 0.5)
elif name == 'Ry':
box = Ry(tk_gate.op.params[0] / 2)
box = Ry(_tk_to_lmbq_param(tk_gate.op.params[0]) * 0.5)
elif name == 'Rz':
box = Rz(tk_gate.op.params[0] / 2)
box = Rz(_tk_to_lmbq_param(tk_gate.op.params[0]) * 0.5)
elif name in GATES:
box = cast(Box, GATES[name])

Expand All @@ -370,11 +395,14 @@ def box_and_offset_from_tk(tk_gate) -> tuple[Diagram, int]:
offset += distance

if name == 'CRx':
box = CRx(tk_gate.op.params[0] / 2, distance)
box = CRx(
_tk_to_lmbq_param(tk_gate.op.params[0]) * 0.5, distance)
elif name == 'CRy':
box = CRy(tk_gate.op.params[0] / 2, distance)
box = CRy(
_tk_to_lmbq_param(tk_gate.op.params[0]) * 0.5, distance)
elif name == 'CRz':
box = CRz(tk_gate.op.params[0] / 2, distance)
box = CRz(
_tk_to_lmbq_param(tk_gate.op.params[0]) * 0.5, distance)
elif name == 'SWAP':
distance = abs(distance)
idx = list(range(distance + 1))
Expand Down
13 changes: 7 additions & 6 deletions lambeq/backend/pennylane.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ def tk_op_to_pennylane(tk_op):
:class:`qml.operation.Operation`
The PennyLane operation equivalent to the input pytket Op.
list of (:class:`torch.FloatTensor` or
:class:`sympy.core.symbol.Symbol`)
:class:`lambeq.backend.symbol.Symbol`)
The parameters of the operation.
list of :class:`sympy.core.symbol.Symbol`
list of :class:`lambeq.backend.symbol.Symbol`
The free symbols in the parameters of the operation.
list of int
The wires/qubits to apply the operation to.
Expand Down Expand Up @@ -137,11 +137,11 @@ def extract_ops_from_tk(tk_circ):
list of :class:`qml.operation.Operation`
The PennyLane operations extracted from the pytket circuit.
list of list of (:class:`torch.FloatTensor` or
:class:`sympy.core.symbol.Symbol`)
:class:`lambeq.backend.symbol.Symbol`)
The corresponding parameters of the operations.
list of list of int
The corresponding wires of the operations.
set of :class:`sympy.core.symbol.Symbol`
set of :class:`lambeq.backend.symbol.Symbol`
The free symbols in the parameters of the tket circuit.
"""
Expand Down Expand Up @@ -250,6 +250,7 @@ def to_pennylane(lambeq_circuit: Diagram, probabilities=False,

class PennyLaneCircuit:
"""Implement a pennylane circuit with post-selection."""

def __init__(self, ops, symbols, params, wires, probabilities,
post_selection, scale, n_qubits, backend_config,
diff_method):
Expand Down Expand Up @@ -343,7 +344,7 @@ def draw(self):
Parameters
----------
symbols : list of :class:`sympy.core.symbol.Symbol`, default: None
symbols : list of :class:`lambeq.Symbol`, default: None
The symbols from the original lambeq circuit.
weights : list of :class:`torch.FloatTensor`, default: None
The weights to substitute for the symbols.
Expand Down Expand Up @@ -468,7 +469,7 @@ def eval(self):
Parameters
----------
symbols : list of :class:`sympy.core.symbol.Symbol`, default: None
symbols : list of :class:`lambeq.Symbol`, default: None
The symbols from the original lambeq circuit.
weights : list of :class:`torch.FloatTensor`, default: None
The weights to substitute for the symbols.
Expand Down
Loading

0 comments on commit de3f538

Please sign in to comment.