Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cirq Decompose (#93) [Draft] #184

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 70 additions & 16 deletions qbraid_qir/cirq/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@

"""
import itertools
from typing import Iterable
from typing import Iterable, Any, Dict, Sequence, Type, Union, TYPE_CHECKING


import cirq
from cirq.transformers.analytical_decompositions import two_qubit_to_cz


from .exceptions import CirqConversionError
from .opsets import map_cirq_op_to_pyqir_callable
Expand All @@ -42,6 +45,71 @@ def _decompose_gate_op(operation: cirq.Operation) -> Iterable[cirq.OP_TREE]:
raise CirqConversionError("Couldn't convert circuit to QIR gate set.")
return list(itertools.chain.from_iterable(map(_decompose_gate_op, new_ops)))

class QIRGateset(cirq.TwoQubitCompilationTargetGateset):
def __init__(
self,
*,
atol: float = 1e-8,
) -> None:
# print(type(cirq.ops.H),)
self.ops = (
# cirq.ops.MeasurementGate,
cirq.ops.H,
cirq.ops.X,
cirq.ops.Y,
cirq.ops.Z,
cirq.ops.S, cirq.ops.T,
cirq.ops.SWAP, cirq.ops.CNOT, cirq.ops.CZ
)

super().__init__(
*self.ops,
name='QIRGateset',
preserve_moment_structure=True,
)
self.atol = atol

def _decompose_two_qubit_operation(self, op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
if not cirq.protocols.has_unitary(op):
return NotImplemented

circ = two_qubit_to_cz.two_qubit_matrix_to_cz_operations(
op.qubits[0],
op.qubits[1],
cirq.protocols.unitary(op),
allow_partial_czs=False,
atol=self.atol,
)

ops = []
for op in circ:
gate = op.gate
if isinstance(gate, cirq.ops.PhasedXPowGate):
ops.extend([
cirq.ops.ZPowGate(exponent=-gate.phase_exponent)(op.qubits[0]),
cirq.ops.XPowGate(exponent=gate.exponent)(op.qubits[0]),
cirq.ops.ZPowGate(exponent=gate.phase_exponent)(op.qubits[0])
])
else:
ops.append(op)

return []

def _value_equality_values_(self) -> Any:
return self.atol, self.allow_partial_czs, frozenset(self.additional_gates)

def _json_dict_(self) -> Dict[str, Any]:
d: Dict[str, Any] = {'atol': self.atol, 'allow_partial_czs': self.allow_partial_czs}
if self.additional_gates:
d['additional_gates'] = list(self.additional_gates)
return d

@classmethod
def _from_json_dict_(cls, atol, allow_partial_czs, additional_gates=(), **kwargs):
return cls(
atol=atol, allow_partial_czs=allow_partial_czs, additional_gates=additional_gates
)


def _decompose_unsupported_gates(circuit: cirq.Circuit) -> cirq.Circuit:
"""
Expand All @@ -53,21 +121,7 @@ def _decompose_unsupported_gates(circuit: cirq.Circuit) -> cirq.Circuit:
Returns:
cirq.Circuit: A new circuit with unsupported gates decomposed.
"""
new_circuit = cirq.Circuit()
for moment in circuit:
new_ops = []
for operation in moment:
if isinstance(operation, cirq.GateOperation):
decomposed_ops = list(_decompose_gate_op(operation))
new_ops.extend(decomposed_ops)
elif isinstance(operation, cirq.ClassicallyControlledOperation):
new_ops.append(operation)
else:
new_ops.append(operation)

new_circuit.append(new_ops)
return new_circuit

return cirq.optimize_for_target_gateset(circuit, gateset=QIRGateset(), ignore_failures=True, max_num_passes=1)

def preprocess_circuit(circuit: cirq.Circuit) -> cirq.Circuit:
"""
Expand Down
Loading