Skip to content

Commit

Permalink
Update Trotter
Browse files Browse the repository at this point in the history
  • Loading branch information
fdmalone committed Dec 6, 2024
1 parent 970343a commit 72f3d78
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 61 deletions.
30 changes: 9 additions & 21 deletions qualtran/bloqs/chemistry/trotter/hubbard/hopping_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from qualtran import Bloq
from qualtran.bloqs.basic_gates import Rz, TGate, ZPowGate
from qualtran.bloqs.bookkeeping import ArbitraryClifford
from qualtran.bloqs.chemistry.trotter.hubbard.hopping import (
_hopping_tile,
_hopping_tile_hwp,
_plaquette,
)
from qualtran.resource_counting.generalizers import PHI
from qualtran.resource_counting import get_cost_value, QECGatesCost


def test_hopping_tile(bloq_autotester):
Expand All @@ -30,27 +27,18 @@ def test_hopping_plaquette(bloq_autotester):
bloq_autotester(_plaquette)


def catch_rotations(bloq) -> Bloq:
if isinstance(bloq, (Rz, ZPowGate)):
if isinstance(bloq, ZPowGate):
return Rz(angle=PHI)
elif abs(float(bloq.angle)) < 1e-12:
return ArbitraryClifford(1)
else:
return Rz(angle=PHI)
return bloq


def test_hopping_tile_t_counts():
bloq = _hopping_tile()
_, counts = bloq.call_graph(generalizer=catch_rotations)
assert counts[TGate()] == 8 * bloq.length**2 // 2
assert counts[Rz(PHI)] == 2 * bloq.length**2 // 2
costs = get_cost_value(bloq, QECGatesCost())
assert costs.t == 8 * bloq.length**2 // 2
assert costs.rotation == 2 * bloq.length**2 // 2


def test_hopping_tile_hwp_t_counts():
bloq = _hopping_tile_hwp()
_, counts = bloq.call_graph(generalizer=catch_rotations)
costs = get_cost_value(bloq, QECGatesCost())
n_rot_par = bloq.length**2 // 2
assert counts[Rz(PHI)] == 2 * n_rot_par.bit_length()
assert counts[TGate()] == 8 * bloq.length**2 // 2 + 2 * 4 * (n_rot_par - n_rot_par.bit_count())
assert costs.rotation == 2 * n_rot_par.bit_length()
assert costs.total_t_count(ts_per_rotation=0) == 8 * bloq.length**2 // 2 + 2 * 4 * (
n_rot_par - n_rot_par.bit_count()
)
32 changes: 25 additions & 7 deletions qualtran/bloqs/chemistry/trotter/hubbard/interaction_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from qualtran.bloqs.basic_gates import Rz, TGate
from qualtran.bloqs.chemistry.trotter.hubbard.hopping_test import catch_rotations
import attrs
from qualtran.bloqs.basic_gates.rotation import ZPowGate
from qualtran.bloqs.chemistry.trotter.hubbard.interaction import _interaction, _interaction_hwp
from qualtran.resource_counting import get_cost_value, QECGatesCost
from qualtran.bloqs.bookkeeping import ArbitraryClifford
from qualtran.resource_counting.generalizers import PHI
from qualtran.bloqs.basic_gates import Rz
from qualtran import Bloq


def catch_rotations(bloq) -> Bloq:
if isinstance(bloq, Rz):
if isinstance(bloq.angle, float) and abs(bloq.angle) < 1e-12:
return ArbitraryClifford(1)
else:
return attrs.evolve(bloq, angle=PHI)
if isinstance(bloq, ZPowGate):
if isinstance(bloq.exponent, float) and abs(bloq.exponent) < 1e-12:
return ArbitraryClifford(1)
else:
return attrs.evolve(bloq, exponent=PHI, global_shift=0)
return bloq


def test_hopping_tile(bloq_autotester):
Expand All @@ -27,14 +45,14 @@ def test_interaction_hwp(bloq_autotester):

def test_interaction_hwp_bloq_counts():
bloq = _interaction_hwp()
_, counts = bloq.call_graph(generalizer=catch_rotations)
costs = get_cost_value(bloq, QECGatesCost(), generalizer=catch_rotations)
n_rot_par = bloq.length**2 // 2
assert counts[Rz(PHI)] == 2 * n_rot_par.bit_length()
assert counts[TGate()] == 2 * 4 * (n_rot_par - n_rot_par.bit_count())
assert costs.rotation == 2 * n_rot_par.bit_length()
assert costs.total_t_count(ts_per_rotation=0) == 2 * 4 * (n_rot_par - n_rot_par.bit_count())


def test_interaction_bloq_counts():
bloq = _interaction()
_, counts = bloq.call_graph(generalizer=catch_rotations)
costs = get_cost_value(bloq, QECGatesCost())
n_rot = bloq.length**2
assert counts[Rz(PHI)] == n_rot
assert costs.rotation == n_rot
Original file line number Diff line number Diff line change
Expand Up @@ -106,17 +106,19 @@
"metadata": {},
"outputs": [],
"source": [
"from typing import Dict, Union, Tuple\n",
"from typing import Tuple\n",
"\n",
"import numpy as np\n",
"import sympy\n",
"import attrs\n",
"\n",
"from qualtran.resource_counting.classify_bloqs import bloq_is_rotation\n",
"from qualtran.bloqs.basic_gates.rotation import ZPowGate\n",
"from qualtran.resource_counting.generalizers import PHI\n",
"from qualtran.cirq_interop.t_complexity_protocol import TComplexity\n",
"from qualtran import Bloq\n",
"from qualtran.bloqs.basic_gates import TGate, Rz\n",
"from qualtran.bloqs.basic_gates import Rz\n",
"from qualtran.bloqs.bookkeeping import ArbitraryClifford\n",
"from qualtran.resource_counting import get_cost_value, QECGatesCost\n",
"\n",
"\n",
"def catch_rotations(bloq) -> Bloq:\n",
Expand All @@ -126,16 +128,19 @@
" return ArbitraryClifford(1)\n",
" else:\n",
" return Rz(angle=PHI, eps=bloq.eps)\n",
" if isinstance(bloq, ZPowGate):\n",
" if isinstance(bloq.exponent, float) and abs(bloq.exponent) < 1e-12:\n",
" return ArbitraryClifford(1)\n",
" else:\n",
" return attrs.evolve(bloq, exponent=PHI, global_shift=0)\n",
" return bloq\n",
"\n",
"\n",
"def t_and_rot_counts_from_sigma(sigma: Dict['Bloq', Union[int, 'sympy.Expr']]) -> Tuple[int, int]:\n",
" ret = sigma.get(TGate(), 0)\n",
" n_rot = 0\n",
" for bloq, counts in sigma.items():\n",
" if bloq_is_rotation(bloq):\n",
" n_rot += counts\n",
" return ret, n_rot\n",
"def t_and_rot_counts_from_bloq(bloq) -> Tuple[int, int]:\n",
" costs = get_cost_value(bloq, QECGatesCost(), generalizer=catch_rotations)\n",
" n_rot = costs.rotation\n",
" n_t = costs.total_t_count(ts_per_rotation=0)\n",
" return n_t, n_rot\n",
"\n",
"\n",
"def timestep_from_params(delta_ts: float, xi: float, prod_ord: int) -> float:\n",
Expand Down Expand Up @@ -265,7 +270,7 @@
"from qualtran.bloqs.chemistry.trotter.hubbard.trotter_step import build_plaq_unitary_second_order_suzuki\n",
"\n",
"trotter_step = build_plaq_unitary_second_order_suzuki(length, hubb_u, timestep, eps=1e-10)\n",
"n_t, n_rot = t_and_rot_counts_from_sigma(trotter_step.call_graph(generalizer=catch_rotations)[1])\n",
"n_t, n_rot = t_and_rot_counts_from_bloq(trotter_step)\n",
"print(f\"N_T = {n_t} vs {(3*length**2 // 2)*8}\")\n",
"print(f\"N_rot = {n_rot} vs {(3 * length**2 + 2*length**2)}\")"
]
Expand All @@ -283,7 +288,6 @@
"metadata": {},
"outputs": [],
"source": [
"import attrs\n",
"from qualtran.drawing import show_call_graph\n",
"# get appropriate epsilon given our input parameters now we know the number of rotations\n",
"eps_single_rot = get_single_rot_eps(n_rot, delta_ht, timestep)\n",
Expand Down Expand Up @@ -418,7 +422,7 @@
"metadata": {},
"outputs": [],
"source": [
"from scipy.optimize import minimize, bisect, newton\n",
"from scipy.optimize import minimize\n",
"def objective(delta_ts, delta_ht, n_rot, n_t, xi_bound, prod_ord):\n",
" t_counts = qpe_t_count(epsilon - delta_ts - delta_ht, delta_ts, delta_ht, n_rot, n_t, xi_bound, prod_ord)\n",
" return t_counts\n",
Expand Down Expand Up @@ -467,7 +471,7 @@
"source": [
"from qualtran.bloqs.chemistry.trotter.hubbard.trotter_step import build_plaq_hwp_unitary_second_order_suzuki\n",
"trotter_step_hwp = build_plaq_hwp_unitary_second_order_suzuki(length, hubb_u, timestep, eps=1e-10)\n",
"n_t_hwp, n_rot_hwp = t_and_rot_counts_from_sigma(trotter_step_hwp.call_graph(generalizer=catch_rotations)[1])\n",
"n_t_hwp, n_rot_hwp = t_and_rot_counts_from_bloq(trotter_step_hwp)\n",
"print(f\"N_T(HWP) = {n_t_hwp} vs {(3*length**2 // 2)*8}\")\n",
"print(f\"N_rot(HWP) = {n_rot_hwp} vs {(3 * length**2 + 2*length**2)}\")\n",
"delta_ht_opt, delta_ts_opt, delta_pe_opt, t_opt = minimize_linesearch(n_rot_hwp, n_t_hwp, xi_bound, prod_ord)\n",
Expand All @@ -490,7 +494,7 @@
"outputs": [],
"source": [
"trotter_step_hwp = build_plaq_hwp_unitary_second_order_suzuki(length, hubb_u, timestep, eps=1e-10, strip_layer=True)\n",
"n_t_hwp, n_rot_hwp = t_and_rot_counts_from_sigma(trotter_step_hwp.call_graph(generalizer=catch_rotations)[1])\n",
"n_t_hwp, n_rot_hwp = t_and_rot_counts_from_bloq(trotter_step_hwp)\n",
"print(f\"N_T(HWP) = {n_t_hwp}\")\n",
"print(f\"N_rot(HWP) = {n_rot_hwp}\")\n",
"delta_ht_opt, delta_ts_opt, delta_pe_opt, t_opt = minimize_linesearch(n_rot_hwp, n_t_hwp, xi_bound, prod_ord)\n",
Expand Down Expand Up @@ -619,7 +623,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
"version": "3.11.9"
}
},
"nbformat": 4,
Expand Down
24 changes: 7 additions & 17 deletions qualtran/bloqs/chemistry/trotter/hubbard/trotter_step_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,36 +13,26 @@
# limitations under the License.
import pytest

from qualtran import Bloq
from qualtran.bloqs.basic_gates import Rz
from qualtran.bloqs.basic_gates.t_gate import TGate
from qualtran.bloqs.bookkeeping import ArbitraryClifford
from qualtran.bloqs.chemistry.trotter.hubbard.trotter_step import (
build_plaq_unitary_second_order_suzuki,
)
from qualtran.resource_counting.generalizers import PHI
from qualtran.bloqs.chemistry.trotter.hubbard.interaction_test import catch_rotations
from qualtran.testing import execute_notebook


def catch_rotations(bloq) -> Bloq:
if isinstance(bloq, Rz):
if isinstance(bloq.angle, float) and abs(bloq.angle) < 1e-12:
return ArbitraryClifford(1)
else:
return Rz(angle=PHI)
return bloq
from qualtran.resource_counting import get_cost_value, QECGatesCost


def test_second_order_suzuki_costs():
length = 8
u = 4
dt = 0.1
unitary = build_plaq_unitary_second_order_suzuki(length, u, dt)
_, sigma = unitary.call_graph(generalizer=catch_rotations)
# _, sigma = unitary.call_graph(generalizer=catch_rotations)
costs = get_cost_value(unitary, QECGatesCost(), generalizer=catch_rotations)
# there are 3 hopping unitaries contributing 8 Ts from from the F gate
assert sigma[TGate()] == (3 * length**2 // 2) * 8
assert costs.total_t_count(ts_per_rotation=0) == (3 * length**2 // 2) * 8
# 3 hopping unitaries and 2 interaction unitaries
assert sigma[Rz(PHI)] == (3 * length**2 + 2 * length**2)
print(costs.rotation, (3 * length**2 + 2 * length**2))
assert costs.rotation == (3 * length**2 + 2 * length**2), 3 * length**2 + 2 * length**2


@pytest.mark.notebook
Expand Down

0 comments on commit 72f3d78

Please sign in to comment.