Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Qiskit native QPY ParameterExpression serialization #13356

Merged
merged 18 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions qiskit/circuit/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def __init__(
self._hash = hash((self._parameter_keys, self._symbol_expr))
self._parameter_symbols = {self: symbol}
self._name_map = None
self._qpy_replay = []
self._standalone_param = True

def assign(self, parameter, value):
if parameter != self:
Expand Down Expand Up @@ -172,3 +174,5 @@ def __setstate__(self, state):
self._hash = hash((self._parameter_keys, self._symbol_expr))
self._parameter_symbols = {self: self._symbol_expr}
self._name_map = None
self._qpy_replay = []
self._standalone_param = True
201 changes: 167 additions & 34 deletions qiskit/circuit/parameterexpression.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
"""

from __future__ import annotations

from dataclasses import dataclass
from enum import IntEnum
from typing import Callable, Union

import numbers
Expand All @@ -30,12 +33,86 @@
ParameterValueType = Union["ParameterExpression", float]


class _OPCode(IntEnum):
ADD = 0
SUB = 1
MUL = 2
DIV = 3
POW = 4
SIN = 5
COS = 6
TAN = 7
ASIN = 8
ACOS = 9
EXP = 10
LOG = 11
SIGN = 12
GRAD = 13
CONJ = 14
SUBSTITUTE = 15
ABS = 16
ATAN = 17
RSUB = 18
RDIV = 19
RPOW = 20


_OP_CODE_MAP = (
"__add__",
"__sub__",
"__mul__",
"__truediv__",
"__pow__",
"sin",
"cos",
"tan",
"arcsin",
"arccos",
"exp",
"log",
"sign",
"gradient",
"conjugate",
"subs",
"abs",
"arctan",
"__rsub__",
"__rtruediv__",
"__rpow__",
)


def op_code_to_method(op_code: _OPCode):
"""Return the method name for a given op_code."""
return _OP_CODE_MAP[op_code]


@dataclass
class _INSTRUCTION:
op: _OPCode
lhs: ParameterValueType | None
rhs: ParameterValueType | None = None


@dataclass
class _SUBS:
binds: dict
op: _OPCode = _OPCode.SUBSTITUTE


class ParameterExpression:
"""ParameterExpression class to enable creating expressions of Parameters."""

__slots__ = ["_parameter_symbols", "_parameter_keys", "_symbol_expr", "_name_map"]
__slots__ = [
"_parameter_symbols",
"_parameter_keys",
"_symbol_expr",
"_name_map",
"_qpy_replay",
"_standalone_param",
]

def __init__(self, symbol_map: dict, expr):
def __init__(self, symbol_map: dict, expr, *, _qpy_replay=None):
"""Create a new :class:`ParameterExpression`.

Not intended to be called directly, but to be instantiated via operations
Expand All @@ -54,6 +131,11 @@ def __init__(self, symbol_map: dict, expr):
self._parameter_keys = frozenset(p._hash_key() for p in self._parameter_symbols)
self._symbol_expr = expr
self._name_map: dict | None = None
self._standalone_param = False
if _qpy_replay is not None:
self._qpy_replay = _qpy_replay
else:
self._qpy_replay = []

@property
def parameters(self) -> set:
Expand All @@ -69,8 +151,14 @@ def _names(self) -> dict:

def conjugate(self) -> "ParameterExpression":
"""Return the conjugate."""
if self._standalone_param:
new_op = _INSTRUCTION(_OPCode.CONJ, self)
else:
new_op = _INSTRUCTION(_OPCode.CONJ, None)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)
conjugated = ParameterExpression(
self._parameter_symbols, symengine.conjugate(self._symbol_expr)
self._parameter_symbols, symengine.conjugate(self._symbol_expr), _qpy_replay=new_replay
)
return conjugated

Expand Down Expand Up @@ -117,6 +205,7 @@ def bind(
self._raise_if_passed_unknown_parameters(parameter_values.keys())
self._raise_if_passed_nan(parameter_values)

new_op = _SUBS(parameter_values)
symbol_values = {}
for parameter, value in parameter_values.items():
if (param_expr := self._parameter_symbols.get(parameter)) is not None:
Expand All @@ -143,7 +232,12 @@ def bind(
f"(Expression: {self}, Bindings: {parameter_values})."
)

return ParameterExpression(free_parameter_symbols, bound_symbol_expr)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)

return ParameterExpression(
free_parameter_symbols, bound_symbol_expr, _qpy_replay=new_replay
)

def subs(
self, parameter_map: dict, allow_unknown_parameters: bool = False
Expand Down Expand Up @@ -175,6 +269,7 @@ def subs(
for p in replacement_expr.parameters
}
self._raise_if_parameter_names_conflict(inbound_names, parameter_map.keys())
new_op = _SUBS(parameter_map)

# Include existing parameters in self not set to be replaced.
new_parameter_symbols = {
Expand All @@ -192,8 +287,12 @@ def subs(
new_parameter_symbols[p] = symbol_type(p.name)

substituted_symbol_expr = self._symbol_expr.subs(symbol_map)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)

return ParameterExpression(new_parameter_symbols, substituted_symbol_expr)
return ParameterExpression(
new_parameter_symbols, substituted_symbol_expr, _qpy_replay=new_replay
)

def _raise_if_passed_unknown_parameters(self, parameters):
unknown_parameters = parameters - self.parameters
Expand Down Expand Up @@ -231,7 +330,11 @@ def _raise_if_parameter_names_conflict(self, inbound_parameters, outbound_parame
)

def _apply_operation(
self, operation: Callable, other: ParameterValueType, reflected: bool = False
self,
operation: Callable,
other: ParameterValueType,
reflected: bool = False,
op_code: _OPCode = None,
) -> "ParameterExpression":
"""Base method implementing math operations between Parameters and
either a constant or a second ParameterExpression.
Expand All @@ -253,7 +356,6 @@ def _apply_operation(
A new expression describing the result of the operation.
"""
self_expr = self._symbol_expr

if isinstance(other, ParameterExpression):
self._raise_if_parameter_names_conflict(other._names)
parameter_symbols = {**self._parameter_symbols, **other._parameter_symbols}
Expand All @@ -266,10 +368,26 @@ def _apply_operation(

if reflected:
expr = operation(other_expr, self_expr)
if op_code in {_OPCode.RSUB, _OPCode.RDIV, _OPCode.RPOW}:
if self._standalone_param:
new_op = _INSTRUCTION(op_code, self, other)
else:
new_op = _INSTRUCTION(op_code, None, other)
else:
if self._standalone_param:
new_op = _INSTRUCTION(op_code, other, self)
else:
new_op = _INSTRUCTION(op_code, other, None)
else:
expr = operation(self_expr, other_expr)

out_expr = ParameterExpression(parameter_symbols, expr)
if self._standalone_param:
new_op = _INSTRUCTION(op_code, self, other)
else:
new_op = _INSTRUCTION(op_code, None, other)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)

out_expr = ParameterExpression(parameter_symbols, expr, _qpy_replay=new_replay)
out_expr._name_map = self._names.copy()
if isinstance(other, ParameterExpression):
out_expr._names.update(other._names.copy())
Expand All @@ -291,6 +409,13 @@ def gradient(self, param) -> Union["ParameterExpression", complex]:
# If it is not contained then return 0
return 0.0

if self._standalone_param:
new_op = _INSTRUCTION(_OPCode.GRAD, self, param)
else:
new_op = _INSTRUCTION(_OPCode.GRAD, None, param)
qpy_replay = self._qpy_replay.copy()
qpy_replay.append(new_op)

# Compute the gradient of the parameter expression w.r.t. param
key = self._parameter_symbols[param]
expr_grad = symengine.Derivative(self._symbol_expr, key)
Expand All @@ -304,7 +429,7 @@ def gradient(self, param) -> Union["ParameterExpression", complex]:
parameter_symbols[parameter] = symbol
# If the gradient corresponds to a parameter expression then return the new expression.
if len(parameter_symbols) > 0:
return ParameterExpression(parameter_symbols, expr=expr_grad)
return ParameterExpression(parameter_symbols, expr=expr_grad, _qpy_replay=qpy_replay)
# If no free symbols left, return a complex or float gradient
expr_grad_cplx = complex(expr_grad)
if expr_grad_cplx.imag != 0:
Expand All @@ -313,81 +438,89 @@ def gradient(self, param) -> Union["ParameterExpression", complex]:
return float(expr_grad)

def __add__(self, other):
return self._apply_operation(operator.add, other)
return self._apply_operation(operator.add, other, op_code=_OPCode.ADD)

def __radd__(self, other):
return self._apply_operation(operator.add, other, reflected=True)
return self._apply_operation(operator.add, other, reflected=True, op_code=_OPCode.ADD)

def __sub__(self, other):
return self._apply_operation(operator.sub, other)
return self._apply_operation(operator.sub, other, op_code=_OPCode.SUB)

def __rsub__(self, other):
return self._apply_operation(operator.sub, other, reflected=True)
return self._apply_operation(operator.sub, other, reflected=True, op_code=_OPCode.RSUB)

def __mul__(self, other):
return self._apply_operation(operator.mul, other)
return self._apply_operation(operator.mul, other, op_code=_OPCode.MUL)

def __pos__(self):
return self._apply_operation(operator.mul, 1)
return self._apply_operation(operator.mul, 1, op_code=_OPCode.MUL)

def __neg__(self):
return self._apply_operation(operator.mul, -1)
return self._apply_operation(operator.mul, -1, op_code=_OPCode.MUL)

def __rmul__(self, other):
return self._apply_operation(operator.mul, other, reflected=True)
return self._apply_operation(operator.mul, other, reflected=True, op_code=_OPCode.MUL)

def __truediv__(self, other):
if other == 0:
raise ZeroDivisionError("Division of a ParameterExpression by zero.")
return self._apply_operation(operator.truediv, other)
return self._apply_operation(operator.truediv, other, op_code=_OPCode.DIV)

def __rtruediv__(self, other):
return self._apply_operation(operator.truediv, other, reflected=True)
return self._apply_operation(operator.truediv, other, reflected=True, op_code=_OPCode.RDIV)

def __pow__(self, other):
return self._apply_operation(pow, other)
return self._apply_operation(pow, other, op_code=_OPCode.POW)

def __rpow__(self, other):
return self._apply_operation(pow, other, reflected=True)
return self._apply_operation(pow, other, reflected=True, op_code=_OPCode.RPOW)

def _call(self, ufunc):
return ParameterExpression(self._parameter_symbols, ufunc(self._symbol_expr))
def _call(self, ufunc, op_code):
if self._standalone_param:
new_op = _INSTRUCTION(op_code, self)
else:
new_op = _INSTRUCTION(op_code, None)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)
return ParameterExpression(
self._parameter_symbols, ufunc(self._symbol_expr), _qpy_replay=new_replay
)

def sin(self):
"""Sine of a ParameterExpression"""
return self._call(symengine.sin)
return self._call(symengine.sin, op_code=_OPCode.SIN)

def cos(self):
"""Cosine of a ParameterExpression"""
return self._call(symengine.cos)
return self._call(symengine.cos, op_code=_OPCode.COS)

def tan(self):
"""Tangent of a ParameterExpression"""
return self._call(symengine.tan)
return self._call(symengine.tan, op_code=_OPCode.TAN)

def arcsin(self):
"""Arcsin of a ParameterExpression"""
return self._call(symengine.asin)
return self._call(symengine.asin, op_code=_OPCode.ASIN)

def arccos(self):
"""Arccos of a ParameterExpression"""
return self._call(symengine.acos)
return self._call(symengine.acos, op_code=_OPCode.ACOS)

def arctan(self):
"""Arctan of a ParameterExpression"""
return self._call(symengine.atan)
return self._call(symengine.atan, op_code=_OPCode.ATAN)

def exp(self):
"""Exponential of a ParameterExpression"""
return self._call(symengine.exp)
return self._call(symengine.exp, op_code=_OPCode.EXP)

def log(self):
"""Logarithm of a ParameterExpression"""
return self._call(symengine.log)
return self._call(symengine.log, op_code=_OPCode.LOG)

def sign(self):
"""Sign of a ParameterExpression"""
return self._call(symengine.sign)
return self._call(symengine.sign, op_code=_OPCode.SIGN)

def __repr__(self):
return f"{self.__class__.__name__}({str(self)})"
Expand Down Expand Up @@ -455,7 +588,7 @@ def __deepcopy__(self, memo=None):

def __abs__(self):
"""Absolute of a ParameterExpression"""
return self._call(symengine.Abs)
return self._call(symengine.Abs, _OPCode.ABS)

def abs(self):
"""Absolute of a ParameterExpression"""
Expand Down
Loading