diff --git a/qbraid_qir/cirq/passes.py b/qbraid_qir/cirq/passes.py index 2592ebf..9b499b4 100644 --- a/qbraid_qir/cirq/passes.py +++ b/qbraid_qir/cirq/passes.py @@ -13,9 +13,12 @@ """ import itertools -from typing import Iterable +from typing import Iterable, Any, Dict, Sequence, Type, Union, TYPE_CHECKING + import cirq +from cirq.transformers.analytical_decompositions import two_qubit_to_cz + from .exceptions import CirqConversionError from .opsets import map_cirq_op_to_pyqir_callable @@ -42,6 +45,71 @@ def _decompose_gate_op(operation: cirq.Operation) -> Iterable[cirq.OP_TREE]: raise CirqConversionError("Couldn't convert circuit to QIR gate set.") return list(itertools.chain.from_iterable(map(_decompose_gate_op, new_ops))) +class QIRGateset(cirq.TwoQubitCompilationTargetGateset): + def __init__( + self, + *, + atol: float = 1e-8, + ) -> None: + # print(type(cirq.ops.H),) + self.ops = ( + # cirq.ops.MeasurementGate, + cirq.ops.H, + cirq.ops.X, + cirq.ops.Y, + cirq.ops.Z, + cirq.ops.S, cirq.ops.T, + cirq.ops.SWAP, cirq.ops.CNOT, cirq.ops.CZ + ) + + super().__init__( + *self.ops, + name='QIRGateset', + preserve_moment_structure=True, + ) + self.atol = atol + + def _decompose_two_qubit_operation(self, op: 'cirq.Operation', _) -> 'cirq.OP_TREE': + if not cirq.protocols.has_unitary(op): + return NotImplemented + + circ = two_qubit_to_cz.two_qubit_matrix_to_cz_operations( + op.qubits[0], + op.qubits[1], + cirq.protocols.unitary(op), + allow_partial_czs=False, + atol=self.atol, + ) + + ops = [] + for op in circ: + gate = op.gate + if isinstance(gate, cirq.ops.PhasedXPowGate): + ops.extend([ + cirq.ops.ZPowGate(exponent=-gate.phase_exponent)(op.qubits[0]), + cirq.ops.XPowGate(exponent=gate.exponent)(op.qubits[0]), + cirq.ops.ZPowGate(exponent=gate.phase_exponent)(op.qubits[0]) + ]) + else: + ops.append(op) + + return [] + + def _value_equality_values_(self) -> Any: + return self.atol, self.allow_partial_czs, frozenset(self.additional_gates) + + def _json_dict_(self) -> Dict[str, Any]: + d: Dict[str, Any] = {'atol': self.atol, 'allow_partial_czs': self.allow_partial_czs} + if self.additional_gates: + d['additional_gates'] = list(self.additional_gates) + return d + + @classmethod + def _from_json_dict_(cls, atol, allow_partial_czs, additional_gates=(), **kwargs): + return cls( + atol=atol, allow_partial_czs=allow_partial_czs, additional_gates=additional_gates + ) + def _decompose_unsupported_gates(circuit: cirq.Circuit) -> cirq.Circuit: """ @@ -53,21 +121,7 @@ def _decompose_unsupported_gates(circuit: cirq.Circuit) -> cirq.Circuit: Returns: cirq.Circuit: A new circuit with unsupported gates decomposed. """ - new_circuit = cirq.Circuit() - for moment in circuit: - new_ops = [] - for operation in moment: - if isinstance(operation, cirq.GateOperation): - decomposed_ops = list(_decompose_gate_op(operation)) - new_ops.extend(decomposed_ops) - elif isinstance(operation, cirq.ClassicallyControlledOperation): - new_ops.append(operation) - else: - new_ops.append(operation) - - new_circuit.append(new_ops) - return new_circuit - + return cirq.optimize_for_target_gateset(circuit, gateset=QIRGateset(), ignore_failures=True, max_num_passes=1) def preprocess_circuit(circuit: cirq.Circuit) -> cirq.Circuit: """