Skip to content

Commit

Permalink
Add Qiskit native QPY ParameterExpression serialization (#13356)
Browse files Browse the repository at this point in the history
* Add Qiskit native QPY ParameterExpression serialization

With the release of symengine 0.13.0 we discovered a version dependence
on the payload format used for serializing symengine expressions. This
was worked around in #13251 but this is not a sustainable solution and
only works for symengine 0.11.0 and 0.13.0 (there was no 0.12.0). While
there was always the option to use sympy to serialize the underlying
symbolic expression (there is a `use_symengine` flag on `qpy.dumps` you
can set to `False` to do this) the sympy serialzation has several
tradeoffs most importantly is much higher runtime overhead. To solve
the issue moving forward a qiskit native representation of the parameter
expression object is necessary for serialization.

This commit bumps the QPY format version to 13 and adds a new
serialization format for ParameterExpression objects. This new format
is a serialization of the API calls made to ParameterExpression that
resulted in the creation of the underlying object. To facilitate this
the ParameterExpression class is expanded to store an internal "replay"
record of the API calls used to construct the ParameterExpression
object. This internal list is what gets serialized by QPY and then on
deserialization the "replay" is replayed to reconstruct the expression
object. This is a different approach to the previous QPY representations
of the ParameterExpression objects which instead represented the internal
state stored in the ParameterExpression object with the symbolic
expression from symengine (or a sympy copy of the expression). Doing
this directly in Qiskit isn't viable though because symengine's internal
expression tree is not exposed to Python directly. There isn't any
method (private or public) to walk the expression tree to construct
a serialization format based off of it. Converting symengine to a sympy
expression and then using sympy's API to walk the expression tree is a
possibility but that would tie us to sympy which would be problematic
for #13267 and #13131, have significant runtime overhead, and it would
be just easier to rely on sympy's native serialization tools.

The tradeoff with this approach is that it does increase the memory
overhead of the `ParameterExpression` class because for each element
in the expression we have to store a record of it. Depending on the
depth of the expression tree this also could be a lot larger than
symengine's internal representation as we store the raw api calls made
to create the ParameterExpression but symengine is likely simplifying
it's internal representation as it builds it out. But I personally think
this tradeoff is worthwhile as it ties the serialization format to the
Qiskit objects instead of relying on a 3rd party library. This also
gives us the flexibility of changing the internal symbolic expression
library internally in the future if we decide to stop using symengine
at any point.

Fixes #13252

* Remove stray comment

* Add format documentation

* Add release note

* Add test and fix some issues with recursive expressions

* Add int type for operands

* Add dedicated subs test

* Pivot to stack based postfix/rpn deserialization

This commit changes how the deserialization works to use a postfix
stack based approach. Operands are push on the stack and then popped off
based on the operation being run. The result of the operation is then
pushed on the stack. This handles nested objects much more cleanly than
the recursion based approach because we just keep pushing on the stack
instead of recursing, making the accounting much simpler. After the
expression payload is finished being processed there will be a single
value on the stack and that is returned as the final expression.

* Apply suggestions from code review

Co-authored-by: Elena Peña Tapia <[email protected]>

* Change DERIV to GRAD

* Change side kwarg to r_side

* Change all the v4s to v13s

* Correctly handle non-commutative operations

This commit fixes a bug with handling the operand order of subtraction,
division, and exponentiation. These operations are not commutative but
the qpy deserialization code was treating them as such. So in cases
where the argument order was reversed qpy was trying to flip the
operands around for code simplicity and this would result in incorrect
behavior. This commit fixes this by adding explicit op codes for the
reversed sub, div, and pow and preserving the operand order correctly
in these cases.

* Fix lint

---------

Co-authored-by: Elena Peña Tapia <[email protected]>
  • Loading branch information
mtreinish and ElePT authored Nov 6, 2024
1 parent 1b35e8b commit 0a7690d
Show file tree
Hide file tree
Showing 9 changed files with 744 additions and 47 deletions.
4 changes: 4 additions & 0 deletions qiskit/circuit/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def __init__(
self._hash = hash((self._parameter_keys, self._symbol_expr))
self._parameter_symbols = {self: symbol}
self._name_map = None
self._qpy_replay = []
self._standalone_param = True

def assign(self, parameter, value):
if parameter != self:
Expand Down Expand Up @@ -172,3 +174,5 @@ def __setstate__(self, state):
self._hash = hash((self._parameter_keys, self._symbol_expr))
self._parameter_symbols = {self: self._symbol_expr}
self._name_map = None
self._qpy_replay = []
self._standalone_param = True
201 changes: 167 additions & 34 deletions qiskit/circuit/parameterexpression.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
"""

from __future__ import annotations

from dataclasses import dataclass
from enum import IntEnum
from typing import Callable, Union

import numbers
Expand All @@ -30,12 +33,86 @@
ParameterValueType = Union["ParameterExpression", float]


class _OPCode(IntEnum):
ADD = 0
SUB = 1
MUL = 2
DIV = 3
POW = 4
SIN = 5
COS = 6
TAN = 7
ASIN = 8
ACOS = 9
EXP = 10
LOG = 11
SIGN = 12
GRAD = 13
CONJ = 14
SUBSTITUTE = 15
ABS = 16
ATAN = 17
RSUB = 18
RDIV = 19
RPOW = 20


_OP_CODE_MAP = (
"__add__",
"__sub__",
"__mul__",
"__truediv__",
"__pow__",
"sin",
"cos",
"tan",
"arcsin",
"arccos",
"exp",
"log",
"sign",
"gradient",
"conjugate",
"subs",
"abs",
"arctan",
"__rsub__",
"__rtruediv__",
"__rpow__",
)


def op_code_to_method(op_code: _OPCode):
"""Return the method name for a given op_code."""
return _OP_CODE_MAP[op_code]


@dataclass
class _INSTRUCTION:
op: _OPCode
lhs: ParameterValueType | None
rhs: ParameterValueType | None = None


@dataclass
class _SUBS:
binds: dict
op: _OPCode = _OPCode.SUBSTITUTE


class ParameterExpression:
"""ParameterExpression class to enable creating expressions of Parameters."""

__slots__ = ["_parameter_symbols", "_parameter_keys", "_symbol_expr", "_name_map"]
__slots__ = [
"_parameter_symbols",
"_parameter_keys",
"_symbol_expr",
"_name_map",
"_qpy_replay",
"_standalone_param",
]

def __init__(self, symbol_map: dict, expr):
def __init__(self, symbol_map: dict, expr, *, _qpy_replay=None):
"""Create a new :class:`ParameterExpression`.
Not intended to be called directly, but to be instantiated via operations
Expand All @@ -54,6 +131,11 @@ def __init__(self, symbol_map: dict, expr):
self._parameter_keys = frozenset(p._hash_key() for p in self._parameter_symbols)
self._symbol_expr = expr
self._name_map: dict | None = None
self._standalone_param = False
if _qpy_replay is not None:
self._qpy_replay = _qpy_replay
else:
self._qpy_replay = []

@property
def parameters(self) -> set:
Expand All @@ -69,8 +151,14 @@ def _names(self) -> dict:

def conjugate(self) -> "ParameterExpression":
"""Return the conjugate."""
if self._standalone_param:
new_op = _INSTRUCTION(_OPCode.CONJ, self)
else:
new_op = _INSTRUCTION(_OPCode.CONJ, None)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)
conjugated = ParameterExpression(
self._parameter_symbols, symengine.conjugate(self._symbol_expr)
self._parameter_symbols, symengine.conjugate(self._symbol_expr), _qpy_replay=new_replay
)
return conjugated

Expand Down Expand Up @@ -117,6 +205,7 @@ def bind(
self._raise_if_passed_unknown_parameters(parameter_values.keys())
self._raise_if_passed_nan(parameter_values)

new_op = _SUBS(parameter_values)
symbol_values = {}
for parameter, value in parameter_values.items():
if (param_expr := self._parameter_symbols.get(parameter)) is not None:
Expand All @@ -143,7 +232,12 @@ def bind(
f"(Expression: {self}, Bindings: {parameter_values})."
)

return ParameterExpression(free_parameter_symbols, bound_symbol_expr)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)

return ParameterExpression(
free_parameter_symbols, bound_symbol_expr, _qpy_replay=new_replay
)

def subs(
self, parameter_map: dict, allow_unknown_parameters: bool = False
Expand Down Expand Up @@ -175,6 +269,7 @@ def subs(
for p in replacement_expr.parameters
}
self._raise_if_parameter_names_conflict(inbound_names, parameter_map.keys())
new_op = _SUBS(parameter_map)

# Include existing parameters in self not set to be replaced.
new_parameter_symbols = {
Expand All @@ -192,8 +287,12 @@ def subs(
new_parameter_symbols[p] = symbol_type(p.name)

substituted_symbol_expr = self._symbol_expr.subs(symbol_map)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)

return ParameterExpression(new_parameter_symbols, substituted_symbol_expr)
return ParameterExpression(
new_parameter_symbols, substituted_symbol_expr, _qpy_replay=new_replay
)

def _raise_if_passed_unknown_parameters(self, parameters):
unknown_parameters = parameters - self.parameters
Expand Down Expand Up @@ -231,7 +330,11 @@ def _raise_if_parameter_names_conflict(self, inbound_parameters, outbound_parame
)

def _apply_operation(
self, operation: Callable, other: ParameterValueType, reflected: bool = False
self,
operation: Callable,
other: ParameterValueType,
reflected: bool = False,
op_code: _OPCode = None,
) -> "ParameterExpression":
"""Base method implementing math operations between Parameters and
either a constant or a second ParameterExpression.
Expand All @@ -253,7 +356,6 @@ def _apply_operation(
A new expression describing the result of the operation.
"""
self_expr = self._symbol_expr

if isinstance(other, ParameterExpression):
self._raise_if_parameter_names_conflict(other._names)
parameter_symbols = {**self._parameter_symbols, **other._parameter_symbols}
Expand All @@ -266,10 +368,26 @@ def _apply_operation(

if reflected:
expr = operation(other_expr, self_expr)
if op_code in {_OPCode.RSUB, _OPCode.RDIV, _OPCode.RPOW}:
if self._standalone_param:
new_op = _INSTRUCTION(op_code, self, other)
else:
new_op = _INSTRUCTION(op_code, None, other)
else:
if self._standalone_param:
new_op = _INSTRUCTION(op_code, other, self)
else:
new_op = _INSTRUCTION(op_code, other, None)
else:
expr = operation(self_expr, other_expr)

out_expr = ParameterExpression(parameter_symbols, expr)
if self._standalone_param:
new_op = _INSTRUCTION(op_code, self, other)
else:
new_op = _INSTRUCTION(op_code, None, other)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)

out_expr = ParameterExpression(parameter_symbols, expr, _qpy_replay=new_replay)
out_expr._name_map = self._names.copy()
if isinstance(other, ParameterExpression):
out_expr._names.update(other._names.copy())
Expand All @@ -291,6 +409,13 @@ def gradient(self, param) -> Union["ParameterExpression", complex]:
# If it is not contained then return 0
return 0.0

if self._standalone_param:
new_op = _INSTRUCTION(_OPCode.GRAD, self, param)
else:
new_op = _INSTRUCTION(_OPCode.GRAD, None, param)
qpy_replay = self._qpy_replay.copy()
qpy_replay.append(new_op)

# Compute the gradient of the parameter expression w.r.t. param
key = self._parameter_symbols[param]
expr_grad = symengine.Derivative(self._symbol_expr, key)
Expand All @@ -304,7 +429,7 @@ def gradient(self, param) -> Union["ParameterExpression", complex]:
parameter_symbols[parameter] = symbol
# If the gradient corresponds to a parameter expression then return the new expression.
if len(parameter_symbols) > 0:
return ParameterExpression(parameter_symbols, expr=expr_grad)
return ParameterExpression(parameter_symbols, expr=expr_grad, _qpy_replay=qpy_replay)
# If no free symbols left, return a complex or float gradient
expr_grad_cplx = complex(expr_grad)
if expr_grad_cplx.imag != 0:
Expand All @@ -313,81 +438,89 @@ def gradient(self, param) -> Union["ParameterExpression", complex]:
return float(expr_grad)

def __add__(self, other):
return self._apply_operation(operator.add, other)
return self._apply_operation(operator.add, other, op_code=_OPCode.ADD)

def __radd__(self, other):
return self._apply_operation(operator.add, other, reflected=True)
return self._apply_operation(operator.add, other, reflected=True, op_code=_OPCode.ADD)

def __sub__(self, other):
return self._apply_operation(operator.sub, other)
return self._apply_operation(operator.sub, other, op_code=_OPCode.SUB)

def __rsub__(self, other):
return self._apply_operation(operator.sub, other, reflected=True)
return self._apply_operation(operator.sub, other, reflected=True, op_code=_OPCode.RSUB)

def __mul__(self, other):
return self._apply_operation(operator.mul, other)
return self._apply_operation(operator.mul, other, op_code=_OPCode.MUL)

def __pos__(self):
return self._apply_operation(operator.mul, 1)
return self._apply_operation(operator.mul, 1, op_code=_OPCode.MUL)

def __neg__(self):
return self._apply_operation(operator.mul, -1)
return self._apply_operation(operator.mul, -1, op_code=_OPCode.MUL)

def __rmul__(self, other):
return self._apply_operation(operator.mul, other, reflected=True)
return self._apply_operation(operator.mul, other, reflected=True, op_code=_OPCode.MUL)

def __truediv__(self, other):
if other == 0:
raise ZeroDivisionError("Division of a ParameterExpression by zero.")
return self._apply_operation(operator.truediv, other)
return self._apply_operation(operator.truediv, other, op_code=_OPCode.DIV)

def __rtruediv__(self, other):
return self._apply_operation(operator.truediv, other, reflected=True)
return self._apply_operation(operator.truediv, other, reflected=True, op_code=_OPCode.RDIV)

def __pow__(self, other):
return self._apply_operation(pow, other)
return self._apply_operation(pow, other, op_code=_OPCode.POW)

def __rpow__(self, other):
return self._apply_operation(pow, other, reflected=True)
return self._apply_operation(pow, other, reflected=True, op_code=_OPCode.RPOW)

def _call(self, ufunc):
return ParameterExpression(self._parameter_symbols, ufunc(self._symbol_expr))
def _call(self, ufunc, op_code):
if self._standalone_param:
new_op = _INSTRUCTION(op_code, self)
else:
new_op = _INSTRUCTION(op_code, None)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)
return ParameterExpression(
self._parameter_symbols, ufunc(self._symbol_expr), _qpy_replay=new_replay
)

def sin(self):
"""Sine of a ParameterExpression"""
return self._call(symengine.sin)
return self._call(symengine.sin, op_code=_OPCode.SIN)

def cos(self):
"""Cosine of a ParameterExpression"""
return self._call(symengine.cos)
return self._call(symengine.cos, op_code=_OPCode.COS)

def tan(self):
"""Tangent of a ParameterExpression"""
return self._call(symengine.tan)
return self._call(symengine.tan, op_code=_OPCode.TAN)

def arcsin(self):
"""Arcsin of a ParameterExpression"""
return self._call(symengine.asin)
return self._call(symengine.asin, op_code=_OPCode.ASIN)

def arccos(self):
"""Arccos of a ParameterExpression"""
return self._call(symengine.acos)
return self._call(symengine.acos, op_code=_OPCode.ACOS)

def arctan(self):
"""Arctan of a ParameterExpression"""
return self._call(symengine.atan)
return self._call(symengine.atan, op_code=_OPCode.ATAN)

def exp(self):
"""Exponential of a ParameterExpression"""
return self._call(symengine.exp)
return self._call(symengine.exp, op_code=_OPCode.EXP)

def log(self):
"""Logarithm of a ParameterExpression"""
return self._call(symengine.log)
return self._call(symengine.log, op_code=_OPCode.LOG)

def sign(self):
"""Sign of a ParameterExpression"""
return self._call(symengine.sign)
return self._call(symengine.sign, op_code=_OPCode.SIGN)

def __repr__(self):
return f"{self.__class__.__name__}({str(self)})"
Expand Down Expand Up @@ -455,7 +588,7 @@ def __deepcopy__(self, memo=None):

def __abs__(self):
"""Absolute of a ParameterExpression"""
return self._call(symengine.Abs)
return self._call(symengine.Abs, _OPCode.ABS)

def abs(self):
"""Absolute of a ParameterExpression"""
Expand Down
Loading

0 comments on commit 0a7690d

Please sign in to comment.