Skip to content

Commit

Permalink
feat: Introduce run_multiple method (#264)
Browse files Browse the repository at this point in the history
Introduces `run_multiple` method to `BraketSimulator` to allow backends to leverage their own batching implementations. Will next publish SDK PR to make use of this interface.
  • Loading branch information
speller26 authored Jun 27, 2024
1 parent dc68323 commit e93e551
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 8 deletions.
9 changes: 6 additions & 3 deletions src/braket/default_simulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ def run(
Args:
circuit_ir (Union[OpenQASMProgram, JaqcdProgram]): Circuit specification.
qubit_count (int, jaqcd-only): Number of qubits.
shots (int, optional): The number of shots to simulate. Default is 0, which
performs a full analytical simulation.
batch_size (int, optional): The size of the circuit partitions to contract,
Expand Down Expand Up @@ -632,7 +631,7 @@ def run_openqasm(
def run_jaqcd(
self,
circuit_ir: JaqcdProgram,
qubit_count: int,
qubit_count: Any = None,
shots: int = 0,
*,
batch_size: int = 1,
Expand All @@ -642,7 +641,7 @@ def run_jaqcd(
Args:
circuit_ir (Program): ir representation of a braket circuit specifying the
instructions to execute.
qubit_count (int): Unused parameter; in signature for backwards-compatibility
qubit_count (Any): Unused parameter; in signature for backwards-compatibility
shots (int): The number of times to run the circuit.
batch_size (int): The size of the circuit partitions to contract,
if applying multiple gates at a time is desired; see `StateVectorSimulation`.
Expand All @@ -657,6 +656,10 @@ def run_jaqcd(
as a result type when shots=0. Or, if StateVector and Amplitude result types
are requested when shots>0.
"""
if qubit_count is not None:
warnings.warn(
f"qubit_count is deprecated for {type(self).__name__} and can be set to None"
)
self._validate_ir_results_compatibility(
circuit_ir.results,
device_action_type=DeviceActionType.JAQCD,
Expand Down
41 changes: 39 additions & 2 deletions src/braket/simulator/braket_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
# language governing permissions and limitations under the License.

from abc import ABC, abstractmethod
from typing import Union
from collections.abc import Sequence
from multiprocessing import Pool
from os import cpu_count
from typing import Optional, Union

from braket.device_schema import DeviceCapabilities
from braket.ir.ahs import Program as AHSProgram
Expand Down Expand Up @@ -49,7 +52,7 @@ def run(
Run the task specified by the given IR.
Extra arguments will contain any additional information necessary to run the task,
such as number of qubits.
such as the extra parameters for AHS simulations.
Args:
ir (Union[OQ3Program, AHSProgram, JaqcdProgram]): The IR representation of the program
Expand All @@ -59,6 +62,40 @@ def run(
representing the results of the simulation.
"""

def run_multiple(
self,
programs: Sequence[Union[OQ3Program, AHSProgram, JaqcdProgram]],
max_parallel: Optional[int] = None,
*args,
**kwargs,
) -> list[Union[GateModelTaskResult, AnalogHamiltonianSimulationTaskResult]]:
"""
Run the tasks specified by the given IR programs.
Extra arguments will contain any additional information necessary to run the tasks,
such as the extra parameters for AHS simulations.
Args:
programs (Sequence[Union[OQ3Program, AHSProgram, JaqcdProgram]]): The IR representations
of the programs
max_parallel (Optional[int]): The maximum number of programs to run in parallel.
Default is the number of logical CPUs.
Returns:
list[Union[GateModelTaskResult, AnalogHamiltonianSimulationTaskResult]]: A list of
result objects, with the ith object being the result of the ith program.
"""
max_parallel = max_parallel or cpu_count()
with Pool(min(max_parallel, len(programs))) as pool:
param_list = [(program, args, kwargs) for program in programs]
results = pool.starmap(self._run_wrapped, param_list)
return results

def _run_wrapped(
self, ir: Union[OQ3Program, AHSProgram, JaqcdProgram], args, kwargs
): # pragma: no cover
return self.run(ir, *args, **kwargs)

@property
@abstractmethod
def properties(self) -> DeviceCapabilities:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ def test_simulator_run_bell_pair(bell_ir, caplog):
simulator = DensityMatrixSimulator()
shots_count = 10000
if isinstance(bell_ir, JaqcdProgram):
result = simulator.run(bell_ir, qubit_count=2, shots=shots_count)
# Ignore qubit_count
result = simulator.run(bell_ir, shots=shots_count)
else:
result = simulator.run(bell_ir, shots=shots_count)

Expand Down Expand Up @@ -392,7 +393,6 @@ def test_properties():
"deviceParameters": GateModelSimulatorDeviceParameters.schema(),
}
)
print(expected_properties)
assert simulator.properties == expected_properties


Expand Down Expand Up @@ -864,3 +864,23 @@ def test_noncontiguous_qubits_openqasm(qasm_file_name):
(np.allclose(measurement, [0, 0]) or np.allclose(measurement, [1, 1]))
for measurement in measurements
)


def test_run_multiple():
payloads = [
OpenQASMProgram(
source=f"""
OPENQASM 3.0;
bit[1] b;
qubit[1] q;
{gate} q[0];
#pragma braket result density_matrix
"""
)
for gate in ["h", "z", "x"]
]
simulator = DensityMatrixSimulator()
results = simulator.run_multiple(payloads, shots=0)
assert np.allclose(results[0].resultTypes[0].value, np.array([[0.5, 0.5], [0.5, 0.5]]))
assert np.allclose(results[1].resultTypes[0].value, np.array([[1, 0], [0, 0]]))
assert np.allclose(results[2].resultTypes[0].value, np.array([[0, 0], [0, 1]]))
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_simulator_run_bell_pair(bell_ir, batch_size, caplog):
shots_count = 10000
if isinstance(bell_ir, JaqcdProgram):
# Ignore qubit_count
result = simulator.run(bell_ir, qubit_count=10, shots=shots_count, batch_size=batch_size)
result = simulator.run(bell_ir, shots=shots_count, batch_size=batch_size)
else:
result = simulator.run(bell_ir, shots=shots_count, batch_size=batch_size)

Expand Down Expand Up @@ -1425,3 +1425,23 @@ def test_noncontiguous_qubits_jaqcd_multiple_targets():

assert result.measuredQubits == [0, 1]
assert result.resultTypes[0].value == -1


def test_run_multiple():
payloads = [
OpenQASMProgram(
source=f"""
OPENQASM 3.0;
bit[1] b;
qubit[1] q;
{gate} q[0];
#pragma braket result state_vector
"""
)
for gate in ["h", "z", "x"]
]
simulator = StateVectorSimulator()
results = simulator.run_multiple(payloads, shots=0)
assert np.allclose(results[0].resultTypes[0].value, np.array([1, 1]) / np.sqrt(2))
assert np.allclose(results[1].resultTypes[0].value, np.array([1, 0]))
assert np.allclose(results[2].resultTypes[0].value, np.array([0, 1]))

0 comments on commit e93e551

Please sign in to comment.