Skip to content

Commit

Permalink
Refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
israelferrazaraujo committed Dec 6, 2023
1 parent 37e2a50 commit efde78f
Showing 1 changed file with 26 additions and 17 deletions.
43 changes: 26 additions & 17 deletions qdna/embedding/util/state_tree_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
"""

from dataclasses import dataclass
import numpy as np
from qiskit.circuit import ParameterExpression
from qiskit.utils import optionals as _optionals
from sympy import sqrt as sp_sqrt #, Abs as sp_abs
from sympy import sqrt as sp_sqrt
import symengine

@dataclass
Expand All @@ -43,14 +44,14 @@ def __str__(self):
)

# def _abs(self):

# # SymPy's `Abs` function doesn't work with hybrid optimization backwards.

# """Square root of a ParameterExpression"""
# """Absolute value of a ParameterExpression"""
# if _optionals.HAS_SYMENGINE:
# return self._call(symengine.Abs)

# return self._call(sp_abs)
# import symengine
# return ParameterExpression(self._parameter_symbols, symengine.Abs(self._symbol_expr))
# else:
# from sympy import Abs as sp_abs
# return ParameterExpression(self._parameter_symbols, sp_abs(self._symbol_expr))

def _sqrt(self):
"""Square root of a ParameterExpression"""
Expand All @@ -60,16 +61,22 @@ def _sqrt(self):
return self._call(sp_sqrt)

def _sign(self):
# SymPy's `sign` function doesn't work with hybrid optimization backwards.
"""Sign of a ParameterExpression"""
# if _optionals.HAS_SYMENGINE:
# return self._call(symengine.sign)

# return self._call(sp_sign)

# SymPy's `sign` function doesn't work with hybrid optimization backwards.
# If self is exactly `-10^-6`, the algorithm is interrupted.
# import symengine
# return ParameterExpression(
# self._parameter_symbols, symengine.sign(self._symbol_expr + 1e-4)
# )
# else:
# from sympy import sign
# return ParameterExpression(
# self._parameter_symbols, sign(self._symbol_expr + 1e-4)
# )

# If self is exactly `-10^-4`, the algorithm is interrupted.
# This is because the `sign` function will return 0, zeroing the norms.
return (self + 1e-6) / _sqrt((self + 1e-6)*(self + 1e-6))
return self / _sqrt(self*self)

def state_decomposition(nqubits, data, normalize=False):
"""
Expand All @@ -82,19 +89,21 @@ def state_decomposition(nqubits, data, normalize=False):
new_nodes = []

# leafs
r = 10**-6
for i, k in enumerate(data):
sign = _sign(k)
# If one of the coordinates of the state vector
# is exactly `-10^-64`, the algorithm breaks.
# is exactly `-r*10^-6`, the algorithm breaks.
# This is because one of the norms will be zero,
# causing division by zero.
k = k + r # prevents division by zero.
sign = _sign(k)
new_nodes.append(
Node(
i,
nqubits,
None,
None,
k + 1e-64, # prevents division by zero.
k,
sign
)
)
Expand Down

0 comments on commit efde78f

Please sign in to comment.