Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create frontend interface #9

Merged
merged 4 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions quaqsim/api/_simulation_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from dataclasses import dataclass

from ..program_ast.program import Program


@dataclass
class SimulationResult:
"""Result of a simulation. The `pulse_schedule_graph` and `simulated_results_graph`
are base64-encoded PNGs."""
pulse_schedule_graph: str | None
simulated_results: list[list[float]] | None
simulated_results_graph: str | None
error: Exception | None


@dataclass
class SimulationRequest:
qua_configuration: dict | None = None
qua_program: Program | None = None
quantum_system: bytes | None = None
channel_map: bytes | None = None
result: SimulationResult | None = None

@property
def can_simulate(self) -> bool:
"""Whether this request can be simulated."""
return (
self.qua_configuration is not None
and self.qua_program is not None
and self.quantum_system is not None
and self.channel_map is not None
)
192 changes: 192 additions & 0 deletions quaqsim/api/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
from dataclasses import asdict
from typing import Annotated, Optional

from fastapi import Body, FastAPI, HTTPException
from fastapi import status as http_status
from fastapi.middleware.wsgi import WSGIMiddleware
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from qiskit.pulse import Schedule
from qiskit.visualization.pulse_v2 import draw, IQXDebugging
from qm.qua import *

from ..architectures.transmon_pair_backend_from_qua import TransmonPairBackendFromQUA
from ..program_to_quantum_pulse_sim_compiler.quantum_pulse_sim_compiler import Compiler
from ._simulation_request import SimulationRequest, SimulationResult
from .frontend import frontend
from .utils import (
load_from_base64,
dump_fig_to_base64,
program_to_ast,
script_to_program,
)

matplotlib.use("agg")


def _get_pulse_schedule_graph(schedules: list[Schedule]):
n = len(schedules)
fig, ax = plt.subplots(
ncols=5,
nrows=n // 5 if n % 5 == 0 else n // 5 + 1,
figsize=(25, 4 * (n // 5)),
dpi=75,
)

for i, schedule in enumerate(schedules):
draw(
program=schedule,
style=IQXDebugging(),
backend=None,
time_range=None,
time_unit="ns",
disable_channels=None,
show_snapshot=True,
show_framechange=True,
show_waveform_info=True,
show_barrier=True,
plotter="mpl2d",
axis=ax[i // 5][i % 5],
)

fig.tight_layout()

return dump_fig_to_base64(fig)


def _get_simulated_results_graph(results):
start, stop, step = -2, 2, 0.1

fig, ax = plt.subplots()
for i, result in enumerate(results):
ax.plot(
np.arange(start, stop, step),
result,
".-",
label=f"Simulated Q{i}",
)
ax.set_ylim(-0.05, 1.05)
fig.legend()

return dump_fig_to_base64(fig)


def create_app():
app = FastAPI()

app.mount("/frontend", WSGIMiddleware(frontend.server))

app.state.simulation_request = SimulationRequest()

@app.post("/api/submit_qua_configuration")
async def submit_qua_configuration(
qua_configuration: Annotated[str, Body(embed=True)],
):
"""Submit QUA configuration. The dict must be serialized with dill and encoded
as a base64 string."""
app.state.simulation_request.qua_configuration = load_from_base64(
qua_configuration
)

@app.post("/api/submit_qua_program")
async def submit_qua_program(
qua_script: Annotated[Optional[str], Body(embed=True)] = None,
qua_program: Annotated[Optional[str], Body(embed=True)] = None,
):
"""Submit QUA script or program. The string or program_ast.Program must be
serialized with dill and encoded as a base64 string.

Warning:
If provided, `qua_script` is executed through `exec()`, which is a security
risk if the input is not trusted.
"""
if qua_script is not None and qua_program is not None:
raise ValueError(
"Only one of `qua_script` and `qua_program` can be provided"
)

if qua_script is not None:
app.state.simulation_request.qua_program = program_to_ast(
script_to_program(load_from_base64(qua_script))
)
elif qua_program is not None:
app.state.simulation_request.qua_program = load_from_base64(qua_program)
else:
raise ValueError("One of `qua_script` and `qua_program` must be provided")

@app.post("/api/submit_quantum_system")
async def submit_quantum_system(quantum_system: Annotated[bytes, Body(embed=True)]):
"""Submit quantum system. The object must be serialized with dill and encoded as
a base64 string."""
app.state.simulation_request.quantum_system = load_from_base64(quantum_system)

@app.post("/api/submit_channel_map")
async def submit_channel_map(channel_map: Annotated[bytes, Body(embed=True)]):
"""Submit quantum system. The dict must be serialized with dill and encoded as a
base64 string."""
app.state.simulation_request.channel_map = load_from_base64(channel_map)

@app.get("/api/simulate")
async def simulate(num_shots: int = 1000):
"""Simulate the system."""
# When this method returns, `self.result` is set.
request: SimulationRequest = app.state.simulation_request

try:
if not request.can_simulate:
raise ValueError("Missing data for simulation.")

# This is a breakdown of `simulate_program`, which gives an easier access to
# the schedules in `simulation`.
compiler = Compiler(config=request.qua_configuration)
simulation = compiler.compile(
request.qua_program,
request.channel_map,
TransmonPairBackendFromQUA(request.quantum_system, request.channel_map),
)
results = simulation.run(num_shots)
except Exception as e:
request.result = SimulationResult(
pulse_schedule_graph=None,
simulated_results=None,
simulated_results_graph=None,
error=e,
)
else:
request.result = SimulationResult(
pulse_schedule_graph=_get_pulse_schedule_graph(simulation.schedules),
simulated_results=results,
simulated_results_graph=_get_simulated_results_graph(results),
error=None,
)

@app.get("/api/status")
async def status() -> dict:
"""Return a dict with the result of the simulation, or an HTTP error if
something went wrong."""
result: SimulationResult = app.state.simulation_request.result

if result is None:
raise HTTPException(
status_code=http_status.HTTP_400_BAD_REQUEST,
detail="/simulate was not called",
)

if result.error is not None:
raise HTTPException(
status_code=http_status.HTTP_400_BAD_REQUEST,
detail=str(app.state.simulation_request.result.error),
)

return asdict(result)

@app.post("/api/reset")
async def reset():
"""Erase any request and simulation."""
app.state.simulation_request = SimulationRequest()

return app


app = create_app()
131 changes: 131 additions & 0 deletions quaqsim/api/frontend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import dash_bootstrap_components as dbc
from dash import Dash, html, callback, Input, Output, ctx
import requests


frontend = Dash(
__name__,
external_stylesheets=[dbc.themes.BOOTSTRAP],
requests_pathname_prefix="/frontend/",
)

controls = html.Div(
[
dbc.Button("Reload", color="primary", id="reload", class_name="me-3"),
dbc.Button("Reset", outline=True, color="danger", id="reset"),
]
)

frontend.layout = dbc.Container(
[
html.H1("qua-qsim frontend"),
html.Hr(),
dbc.Row(
dbc.Col(
controls,
),
),
dbc.Row(
dbc.Col(
html.Div("Message", id="message"),
),
class_name="mt-3",
id="message-row",
style={"display": "none"},
),
html.H2("Pulse schedule", className="mt-3"),
html.Hr(),
dbc.Row(
dbc.Col(
html.Img(id="pulse-schedule"),
width=6,
),
class_name="mt-3",
id="pulse-schedule-row",
style={"display": "none"},
),
html.H2("Simulated results", className="mt-3"),
html.Hr(),
dbc.Row(
dbc.Col(
html.Img(id="simulated-results"),
width=6,
),
class_name="mt-3",
id="simulated-results-row",
style={"display": "none"},
),
],
fluid=True,
)


@callback(
[
Output("message", "children"),
Output("message-row", "style"),
Output("pulse-schedule", "src"),
Output("pulse-schedule-row", "style"),
Output("simulated-results", "src"),
Output("simulated-results-row", "style"),
],
Input("reload", "n_clicks"),
Input("reset", "n_clicks"),
prevent_initial_call=True,
)
def update_simulated_results(reload, reset):
triggered_id = ctx.triggered_id

if triggered_id == "reload":
return _reload_simulated_results()
elif triggered_id == "reset":
return _reset_simulated_results()


def _reload_simulated_results():
response = requests.get("http://localhost:8000/api/status")

if response.status_code == 200:
pulse_schedule_graph = response.json()["pulse_schedule_graph"]
simulated_results_graph = response.json()["simulated_results_graph"]
return (
"",
{"display": "none"},
f"data:image/png;base64,{pulse_schedule_graph}",
{"display": "block"},
f"data:image/png;base64,{simulated_results_graph}",
{"display": "block"},
)
else:
error_message = response.json()["detail"]
return (
f"Error: “{error_message}”.",
{"display": "block"},
"",
{"display": "none"},
"",
{"display": "none"},
)


def _reset_simulated_results():
response = requests.post("http://localhost:8000/api/reset")

if response.status_code == 200:
return (
"Simulation was reset successfully.",
{"display": "block"},
"",
{"display": "none"},
"",
{"display": "none"},
)

return (
f"Error: {response.text}.",
{"display": "block"},
"",
{"display": "none"},
"",
{"display": "none"},
)
Loading