Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Qiskit native QPY ParameterExpression serialization #13356

Merged
merged 18 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions qiskit/circuit/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(
self._hash = hash((self._parameter_keys, self._symbol_expr))
self._parameter_symbols = {self: symbol}
self._name_map = None
self._qpy_replay = []

def assign(self, parameter, value):
if parameter != self:
Expand Down Expand Up @@ -172,3 +173,4 @@ def __setstate__(self, state):
self._hash = hash((self._parameter_keys, self._symbol_expr))
self._parameter_symbols = {self: self._symbol_expr}
self._name_map = None
self._qpy_replay = []
166 changes: 133 additions & 33 deletions qiskit/circuit/parameterexpression.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
"""

from __future__ import annotations

from dataclasses import dataclass
from enum import IntEnum
from typing import Callable, Union

import numbers
Expand All @@ -30,12 +33,79 @@
ParameterValueType = Union["ParameterExpression", float]


class _OPCode(IntEnum):
ADD = 0
SUB = 1
MUL = 2
DIV = 3
POW = 4
SIN = 5
COS = 6
TAN = 7
ASIN = 8
ACOS = 9
EXP = 10
LOG = 11
SIGN = 12
DERIV = 13
mtreinish marked this conversation as resolved.
Show resolved Hide resolved
CONJ = 14
SUBSTITUTE = 15
ABS = 16
ATAN = 17


_OP_CODE_MAP = (
"__add__",
"__sub__",
"__mul__",
"__truediv__",
"__pow__",
"sin",
"cos",
"tan",
"arcsin",
"arccos",
"exp",
"log",
"sign",
"gradient",
"conjugate",
"subs",
"abs",
"arctan",
)


def op_code_to_method(op_code: _OPCode):
"""Return the method name for a given op_code."""
return _OP_CODE_MAP[op_code]


@dataclass
class _INSTRUCTION:
op: _OPCode
lhs: ParameterValueType
rhs: ParameterValueType | None = None


@dataclass
class _SUBS:
binds: dict
op: _OPCode = _OPCode.SUBSTITUTE


class ParameterExpression:
"""ParameterExpression class to enable creating expressions of Parameters."""

__slots__ = ["_parameter_symbols", "_parameter_keys", "_symbol_expr", "_name_map"]
__slots__ = [
"_parameter_symbols",
"_parameter_keys",
"_symbol_expr",
"_name_map",
"_qpy_replay",
]

def __init__(self, symbol_map: dict, expr):
def __init__(self, symbol_map: dict, expr, *, _qpy_replay=None):
"""Create a new :class:`ParameterExpression`.

Not intended to be called directly, but to be instantiated via operations
Expand All @@ -54,6 +124,10 @@ def __init__(self, symbol_map: dict, expr):
self._parameter_keys = frozenset(p._hash_key() for p in self._parameter_symbols)
self._symbol_expr = expr
self._name_map: dict | None = None
if _qpy_replay is not None:
self._qpy_replay = _qpy_replay
else:
self._qpy_replay = []

@property
def parameters(self) -> set:
Expand All @@ -69,8 +143,11 @@ def _names(self) -> dict:

def conjugate(self) -> "ParameterExpression":
"""Return the conjugate."""
new_op = _INSTRUCTION(_OPCode.CONJ, self)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)
conjugated = ParameterExpression(
self._parameter_symbols, symengine.conjugate(self._symbol_expr)
self._parameter_symbols, symengine.conjugate(self._symbol_expr), _qpy_replay=new_replay
)
return conjugated

Expand Down Expand Up @@ -117,6 +194,7 @@ def bind(
self._raise_if_passed_unknown_parameters(parameter_values.keys())
self._raise_if_passed_nan(parameter_values)

new_op = _SUBS(parameter_values)
symbol_values = {}
for parameter, value in parameter_values.items():
if (param_expr := self._parameter_symbols.get(parameter)) is not None:
Expand All @@ -143,7 +221,12 @@ def bind(
f"(Expression: {self}, Bindings: {parameter_values})."
)

return ParameterExpression(free_parameter_symbols, bound_symbol_expr)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)

return ParameterExpression(
free_parameter_symbols, bound_symbol_expr, _qpy_replay=new_replay
)

def subs(
self, parameter_map: dict, allow_unknown_parameters: bool = False
Expand Down Expand Up @@ -175,6 +258,7 @@ def subs(
for p in replacement_expr.parameters
}
self._raise_if_parameter_names_conflict(inbound_names, parameter_map.keys())
new_op = _SUBS(parameter_map)

# Include existing parameters in self not set to be replaced.
new_parameter_symbols = {
Expand All @@ -192,8 +276,12 @@ def subs(
new_parameter_symbols[p] = symbol_type(p.name)

substituted_symbol_expr = self._symbol_expr.subs(symbol_map)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)

return ParameterExpression(new_parameter_symbols, substituted_symbol_expr)
return ParameterExpression(
new_parameter_symbols, substituted_symbol_expr, _qpy_replay=new_replay
)

def _raise_if_passed_unknown_parameters(self, parameters):
unknown_parameters = parameters - self.parameters
Expand Down Expand Up @@ -231,7 +319,11 @@ def _raise_if_parameter_names_conflict(self, inbound_parameters, outbound_parame
)

def _apply_operation(
self, operation: Callable, other: ParameterValueType, reflected: bool = False
self,
operation: Callable,
other: ParameterValueType,
reflected: bool = False,
op_code: _OPCode = None,
) -> "ParameterExpression":
"""Base method implementing math operations between Parameters and
either a constant or a second ParameterExpression.
Expand All @@ -253,7 +345,6 @@ def _apply_operation(
A new expression describing the result of the operation.
"""
self_expr = self._symbol_expr

if isinstance(other, ParameterExpression):
self._raise_if_parameter_names_conflict(other._names)
parameter_symbols = {**self._parameter_symbols, **other._parameter_symbols}
Expand All @@ -266,10 +357,14 @@ def _apply_operation(

if reflected:
expr = operation(other_expr, self_expr)
new_op = _INSTRUCTION(op_code, other, self)
else:
expr = operation(self_expr, other_expr)
new_op = _INSTRUCTION(op_code, self, other)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)

out_expr = ParameterExpression(parameter_symbols, expr)
out_expr = ParameterExpression(parameter_symbols, expr, _qpy_replay=new_replay)
out_expr._name_map = self._names.copy()
if isinstance(other, ParameterExpression):
out_expr._names.update(other._names.copy())
Expand Down Expand Up @@ -313,81 +408,86 @@ def gradient(self, param) -> Union["ParameterExpression", complex]:
return float(expr_grad)

def __add__(self, other):
return self._apply_operation(operator.add, other)
return self._apply_operation(operator.add, other, op_code=_OPCode.ADD)

def __radd__(self, other):
return self._apply_operation(operator.add, other, reflected=True)
return self._apply_operation(operator.add, other, reflected=True, op_code=_OPCode.ADD)

def __sub__(self, other):
return self._apply_operation(operator.sub, other)
return self._apply_operation(operator.sub, other, op_code=_OPCode.SUB)

def __rsub__(self, other):
return self._apply_operation(operator.sub, other, reflected=True)
return self._apply_operation(operator.sub, other, reflected=True, op_code=_OPCode.SUB)

def __mul__(self, other):
return self._apply_operation(operator.mul, other)
return self._apply_operation(operator.mul, other, op_code=_OPCode.MUL)

def __pos__(self):
return self._apply_operation(operator.mul, 1)
return self._apply_operation(operator.mul, 1, op_code=_OPCode.MUL)

def __neg__(self):
return self._apply_operation(operator.mul, -1)
return self._apply_operation(operator.mul, -1, op_code=_OPCode.MUL)

def __rmul__(self, other):
return self._apply_operation(operator.mul, other, reflected=True)
return self._apply_operation(operator.mul, other, reflected=True, op_code=_OPCode.MUL)

def __truediv__(self, other):
if other == 0:
raise ZeroDivisionError("Division of a ParameterExpression by zero.")
return self._apply_operation(operator.truediv, other)
return self._apply_operation(operator.truediv, other, op_code=_OPCode.DIV)

def __rtruediv__(self, other):
return self._apply_operation(operator.truediv, other, reflected=True)
return self._apply_operation(operator.truediv, other, reflected=True, op_code=_OPCode.DIV)

def __pow__(self, other):
return self._apply_operation(pow, other)
return self._apply_operation(pow, other, op_code=_OPCode.POW)

def __rpow__(self, other):
return self._apply_operation(pow, other, reflected=True)

def _call(self, ufunc):
return ParameterExpression(self._parameter_symbols, ufunc(self._symbol_expr))
return self._apply_operation(pow, other, reflected=True, op_code=_OPCode.POW)

def _call(self, ufunc, op_code):
new_op = _INSTRUCTION(op_code, self)
new_replay = self._qpy_replay.copy()
new_replay.append(new_op)
return ParameterExpression(
self._parameter_symbols, ufunc(self._symbol_expr), _qpy_replay=new_replay
)

def sin(self):
"""Sine of a ParameterExpression"""
return self._call(symengine.sin)
return self._call(symengine.sin, op_code=_OPCode.SIN)

def cos(self):
"""Cosine of a ParameterExpression"""
return self._call(symengine.cos)
return self._call(symengine.cos, op_code=_OPCode.COS)

def tan(self):
"""Tangent of a ParameterExpression"""
return self._call(symengine.tan)
return self._call(symengine.tan, op_code=_OPCode.TAN)

def arcsin(self):
"""Arcsin of a ParameterExpression"""
return self._call(symengine.asin)
return self._call(symengine.asin, op_code=_OPCode.ASIN)

def arccos(self):
"""Arccos of a ParameterExpression"""
return self._call(symengine.acos)
return self._call(symengine.acos, op_code=_OPCode.ACOS)

def arctan(self):
"""Arctan of a ParameterExpression"""
return self._call(symengine.atan)
return self._call(symengine.atan, op_code=_OPCode.ATAN)

def exp(self):
"""Exponential of a ParameterExpression"""
return self._call(symengine.exp)
return self._call(symengine.exp, op_code=_OPCode.EXP)

def log(self):
"""Logarithm of a ParameterExpression"""
return self._call(symengine.log)
return self._call(symengine.log, op_code=_OPCode.LOG)

def sign(self):
"""Sign of a ParameterExpression"""
return self._call(symengine.sign)
return self._call(symengine.sign, op_code=_OPCode.SIGN)

def __repr__(self):
return f"{self.__class__.__name__}({str(self)})"
Expand Down Expand Up @@ -455,7 +555,7 @@ def __deepcopy__(self, memo=None):

def __abs__(self):
"""Absolute of a ParameterExpression"""
return self._call(symengine.Abs)
return self._call(symengine.Abs, _OPCode.ABS)

def abs(self):
"""Absolute of a ParameterExpression"""
Expand Down
Loading