Skip to content

Commit

Permalink
change: Improved wires support (#55)
Browse files Browse the repository at this point in the history
*Description of changes:* In PL-core, it is possible for `wires` to be a list of strings and/or numbers:
```python
import pennylane as qml

dev = qml.device("default.qubit", wires=["a", -1])

@qml.qnode(dev)
def f(x):
    qml.RX(x, wires="a")
    qml.RZ(x, wires=-1)
    return qml.expval(qml.PauliZ("a") @ qml.PauliZ(-1))

f(0.3)
```
This PR extends support to the Braket device.

*Testing done:* A `test_wires` test has been added. Furthermore, we have added the `braket.local.qubit` device to the [PL testing matrix](https://github.com/PennyLaneAI/plugin-test-matrix#testing-matrix). Here, the device is tested against the [standard device test suite](https://github.com/PennyLaneAI/pennylane/tree/master/pennylane/devices/tests). It was the failures there within the `test_wires.py` file that made us notice this.
  • Loading branch information
trbromley authored Feb 2, 2021
1 parent 5b38101 commit d10add2
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 20 deletions.
32 changes: 20 additions & 12 deletions src/braket/pennylane_plugin/braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
"""

# pylint: disable=invalid-name
from typing import FrozenSet, List, Optional, Sequence, Union
from typing import FrozenSet, Iterable, List, Optional, Sequence, Union

from braket.aws import AwsDevice, AwsDeviceType, AwsQuantumTask, AwsQuantumTaskBatch, AwsSession
from braket.circuits import Circuit, Instruction
Expand Down Expand Up @@ -61,7 +61,9 @@ class BraketQubitDevice(QubitDevice):
r"""Abstract Amazon Braket qubit device for PennyLane.
Args:
wires (int): the number of modes to initialize the device in.
wires (int or Iterable[Number, str]]): Number of subsystems represented by the device,
or iterable that contains unique labels for the subsystems as numbers
(i.e., ``[-1, 0, 2]``) or strings (``['ancilla', 'q1', 'q2']``).
device (Device): The Amazon Braket device to use with PennyLane.
shots (int): Number of circuit evaluations or random samples included,
to estimate expectation values of observables. If this value is set to 0,
Expand All @@ -76,7 +78,7 @@ class BraketQubitDevice(QubitDevice):

def __init__(
self,
wires: int,
wires: Union[int, Iterable],
device: Device,
*,
shots: int,
Expand Down Expand Up @@ -124,7 +126,8 @@ def _pl_to_braket_circuit(self, circuit, **run_kwargs):
**run_kwargs,
)
for observable in circuit.observables:
braket_circuit.add_result_type(translate_result_type(observable))
dev_wires = self.map_wires(observable.wires).tolist()
braket_circuit.add_result_type(translate_result_type(observable, dev_wires))
return braket_circuit

def statistics(
Expand Down Expand Up @@ -183,7 +186,8 @@ def apply(
for operation in operations + rotations:
params = [p.numpy() if isinstance(p, np.tensor) else p for p in operation.parameters]
gate = translate_operation(operation, params)
ins = Instruction(gate, operation.wires.tolist())
dev_wires = self.map_wires(operation.wires).tolist()
ins = Instruction(gate, dev_wires)
circuit.add_instruction(ins)

unused = set(range(self.num_wires)) - {int(qubit) for qubit in circuit.qubits}
Expand All @@ -197,16 +201,18 @@ def apply(
def _run_task(self, circuit):
raise NotImplementedError("Need to implement task runner")

@staticmethod
def _get_statistic(braket_result, observable):
return braket_result.get_value_by_result_type(translate_result_type(observable))
def _get_statistic(self, braket_result, observable):
dev_wires = self.map_wires(observable.wires).tolist()
return braket_result.get_value_by_result_type(translate_result_type(observable, dev_wires))


class BraketAwsQubitDevice(BraketQubitDevice):
r"""Amazon Braket AwsDevice qubit device for PennyLane.
Args:
wires (int): the number of modes to initialize the device in.
wires (int or Iterable[Number, str]]): Number of subsystems represented by the device,
or iterable that contains unique labels for the subsystems as numbers
(i.e., ``[-1, 0, 2]``) or strings (``['ancilla', 'q1', 'q2']``).
device_arn (str): The ARN identifying the ``AwsDevice`` to be used to
run circuits; The corresponding AwsDevice must support quantum
circuits via JAQCD. You can get device ARNs using ``AwsDevice.get_devices``,
Expand Down Expand Up @@ -245,7 +251,7 @@ class BraketAwsQubitDevice(BraketQubitDevice):

def __init__(
self,
wires: int,
wires: Union[int, Iterable],
device_arn: str,
s3_destination_folder: AwsSession.S3DestinationFolder,
*,
Expand Down Expand Up @@ -333,7 +339,9 @@ class BraketLocalQubitDevice(BraketQubitDevice):
r"""Amazon Braket LocalSimulator qubit device for PennyLane.
Args:
wires (int): the number of modes to initialize the device in.
wires (int or Iterable[Number, str]]): Number of subsystems represented by the device,
or iterable that contains unique labels for the subsystems as numbers
(i.e., ``[-1, 0, 2]``) or strings (``['ancilla', 'q1', 'q2']``).
backend (Union[str, BraketSimulator]): The name of the simulator backend or
the actual simulator instance to use for simulation. Defaults to the
``default`` simulator backend name.
Expand All @@ -349,7 +357,7 @@ class BraketLocalQubitDevice(BraketQubitDevice):

def __init__(
self,
wires: int,
wires: Union[int, Iterable],
backend: Union[str, BraketSimulator] = "default",
*,
shots: int = 0,
Expand Down
7 changes: 4 additions & 3 deletions src/braket/pennylane_plugin/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# language governing permissions and limitations under the License.

from functools import singledispatch
from typing import FrozenSet
from typing import FrozenSet, List

import numpy as np
import pennylane as qml
Expand Down Expand Up @@ -217,17 +217,18 @@ def _(zz: ZZ, parameters):
return gates.ZZ(-phi) if zz.inverse else gates.ZZ(phi)


def translate_result_type(observable: Observable) -> ResultType:
def translate_result_type(observable: Observable, targets: List[int]) -> ResultType:
"""Translates a PennyLane ``Observable`` into the corresponding Braket ``ResultType``.
Args:
observable (Observable): The PennyLane ``Observable`` to translate
targets (List[int]): The target wires of the observable using a consecutive integer wire
ordering
Returns:
ResultType: The Braket result type corresponding to the given observable
"""
return_type = observable.return_type
targets = observable.wires.tolist()

if return_type is ObservableReturnTypes.Probability:
return Probability(targets)
Expand Down
18 changes: 17 additions & 1 deletion test/unit_tests/test_braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def test_apply_unsupported():
def test_apply_unwrap_tensor():
"""Test that apply() unwraps tensors from the PennyLane version of NumPy into standard NumPy
arrays (or floats)"""
dev = _device(wires=0)
dev = _device(wires=1)

a = anp.array(0.6) # array
b = np.array(0.5, requires_grad=True) # tensor
Expand Down Expand Up @@ -501,6 +501,22 @@ def test_invalid_device_type():
_device(wires=2, device_type="foo", shots=None)


def test_wires():
"""Test if the apply method supports custom wire labels"""

wires = ["A", 0, "B", -1]
dev = _device(wires=wires, device_type=AwsDeviceType.SIMULATOR, shots=0)

ops = [qml.RX(0.1, wires="A"), qml.CNOT(wires=[0, "B"]), qml.RY(0.3, wires=-1)]
target_wires = [[0], [1, 2], [3]]
circ = dev.apply(ops)

for op, targets in zip(circ.instructions, target_wires):
wires = op.target
for w, t in zip(wires, targets):
assert w == t


@patch.object(AwsDevice, "type", new_callable=mock.PropertyMock)
@patch.object(AwsDevice, "properties")
@patch.object(AwsDevice, "refresh_metadata", return_value=None)
Expand Down
8 changes: 4 additions & 4 deletions test/unit_tests/test_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_translate_result_type_observable(return_type, braket_result):
Braket result using translate_result_type"""
obs = qml.Hadamard(0)
obs.return_type = return_type
braket_result_calculated = translate_result_type(obs)
braket_result_calculated = translate_result_type(obs, [0])

assert braket_result == braket_result_calculated

Expand All @@ -47,7 +47,7 @@ def test_translate_result_type_probs():
"""Tests if a PennyLane probability return type is successfully converted into a Braket
result using translate_result_type"""
mp = MeasurementProcess(ObservableReturnTypes.Probability, wires=Wires([0]))
braket_result_calculated = translate_result_type(mp)
braket_result_calculated = translate_result_type(mp, [0])

braket_result = Probability([0])

Expand All @@ -61,7 +61,7 @@ def test_translate_result_type_unsupported_return():
obs.return_type = None

with pytest.raises(NotImplementedError, match="Unsupported return type"):
translate_result_type(obs)
translate_result_type(obs, [0])


def test_translate_result_type_unsupported_obs():
Expand All @@ -70,4 +70,4 @@ def test_translate_result_type_unsupported_obs():
obs.return_type = None

with pytest.raises(TypeError, match="Unsupported observable"):
translate_result_type(obs)
translate_result_type(obs, [0])

0 comments on commit d10add2

Please sign in to comment.