Skip to content

Commit

Permalink
Use rust gates for ConsolidateBlocks (#12704)
Browse files Browse the repository at this point in the history
* Use rust gates for ConsolidateBlocks

This commit moves to use rust gates for the ConsolidateBlocks transpiler
pass. Instead of generating the unitary matrices for the gates in a 2q
block Python side and passing that list to a rust function this commit
switches to passing a list of DAGOpNodes to the rust and then generating
the matrices inside the rust function directly. This is similar to what
was done in #12650 for Optimize1qGatesDecomposition. Besides being faster
to get the matrix for standard gates, it also reduces the eager
construction of Python gate objects which was a significant source of
overhead after #12459. To that end this builds on the thread of work in
the two PRs #12692 and #12701 which changed the access patterns for
other passes to minimize eager gate object construction.

* Add rust filter function for DAGCircuit.collect_2q_runs()

* Update crates/accelerate/src/convert_2q_block_matrix.rs

---------

Co-authored-by: John Lapeyre <[email protected]>
  • Loading branch information
mtreinish and jlapeyre authored Jul 8, 2024
1 parent 4867e8a commit 4fe9dbc
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 86 deletions.
104 changes: 96 additions & 8 deletions crates/accelerate/src/convert_2q_block_matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::wrap_pyfunction;
use pyo3::Python;

Expand All @@ -20,32 +22,84 @@ use numpy::ndarray::{aview2, Array2, ArrayView2};
use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2};
use smallvec::SmallVec;

use qiskit_circuit::bit_data::BitData;
use qiskit_circuit::circuit_instruction::{operation_type_to_py, CircuitInstruction};
use qiskit_circuit::dag_node::DAGOpNode;
use qiskit_circuit::gate_matrix::ONE_QUBIT_IDENTITY;
use qiskit_circuit::imports::QI_OPERATOR;
use qiskit_circuit::operations::{Operation, OperationType};

use crate::QiskitError;

fn get_matrix_from_inst<'py>(
py: Python<'py>,
inst: &'py CircuitInstruction,
) -> PyResult<Array2<Complex64>> {
match inst.operation.matrix(&inst.params) {
Some(mat) => Ok(mat),
None => match inst.operation {
OperationType::Standard(_) => Err(QiskitError::new_err(
"Parameterized gates can't be consolidated",
)),
OperationType::Gate(_) => Ok(QI_OPERATOR
.get_bound(py)
.call1((operation_type_to_py(py, inst)?,))?
.getattr(intern!(py, "data"))?
.extract::<PyReadonlyArray2<Complex64>>()?
.as_array()
.to_owned()),
_ => unreachable!("Only called for unitary ops"),
},
}
}

/// Return the matrix Operator resulting from a block of Instructions.
#[pyfunction]
#[pyo3(text_signature = "(op_list, /")]
pub fn blocks_to_matrix(
py: Python,
op_list: Vec<(PyReadonlyArray2<Complex64>, SmallVec<[u8; 2]>)>,
op_list: Vec<PyRef<DAGOpNode>>,
block_index_map_dict: &Bound<PyDict>,
) -> PyResult<Py<PyArray2<Complex64>>> {
// Build a BitData in block_index_map_dict order. block_index_map_dict is a dict of bits to
// indices mapping the order of the qargs in the block. There should only be 2 entries since
// there are only 2 qargs here (e.g. `{Qubit(): 0, Qubit(): 1}`) so we need to ensure that
// we added the qubits to bit data in the correct index order.
let mut index_map: Vec<PyObject> = (0..block_index_map_dict.len()).map(|_| py.None()).collect();
for bit_tuple in block_index_map_dict.items() {
let (bit, index): (PyObject, usize) = bit_tuple.extract()?;
index_map[index] = bit;
}
let mut bit_map: BitData<u32> = BitData::new(py, "qargs".to_string());
for bit in index_map {
bit_map.add(py, bit.bind(py), true)?;
}
let identity = aview2(&ONE_QUBIT_IDENTITY);
let input_matrix = op_list[0].0.as_array();
let mut matrix: Array2<Complex64> = match op_list[0].1.as_slice() {
let first_node = &op_list[0];
let input_matrix = get_matrix_from_inst(py, &first_node.instruction)?;
let mut matrix: Array2<Complex64> = match bit_map
.map_bits(first_node.instruction.qubits.bind(py).iter())?
.collect::<Vec<_>>()
.as_slice()
{
[0] => kron(&identity, &input_matrix),
[1] => kron(&input_matrix, &identity),
[0, 1] => input_matrix.to_owned(),
[1, 0] => change_basis(input_matrix),
[0, 1] => input_matrix,
[1, 0] => change_basis(input_matrix.view()),
[] => Array2::eye(4),
_ => unreachable!(),
};
for (op_matrix, q_list) in op_list.into_iter().skip(1) {
let op_matrix = op_matrix.as_array();
for node in op_list.into_iter().skip(1) {
let op_matrix = get_matrix_from_inst(py, &node.instruction)?;
let q_list = bit_map
.map_bits(node.instruction.qubits.bind(py).iter())?
.map(|x| x as u8)
.collect::<SmallVec<[u8; 2]>>();

let result = match q_list.as_slice() {
[0] => Some(kron(&identity, &op_matrix)),
[1] => Some(kron(&op_matrix, &identity)),
[1, 0] => Some(change_basis(op_matrix)),
[1, 0] => Some(change_basis(op_matrix.view())),
[] => Some(Array2::eye(4)),
_ => None,
};
Expand All @@ -71,8 +125,42 @@ pub fn change_basis(matrix: ArrayView2<Complex64>) -> Array2<Complex64> {
trans_matrix
}

#[pyfunction]
pub fn collect_2q_blocks_filter(node: &Bound<PyAny>) -> Option<bool> {
match node.downcast::<DAGOpNode>() {
Ok(bound_node) => {
let node = bound_node.borrow();
match &node.instruction.operation {
OperationType::Standard(gate) => Some(
gate.num_qubits() <= 2
&& node
.instruction
.extra_attrs
.as_ref()
.and_then(|attrs| attrs.condition.as_ref())
.is_none()
&& !node.is_parameterized(),
),
OperationType::Gate(gate) => Some(
gate.num_qubits() <= 2
&& node
.instruction
.extra_attrs
.as_ref()
.and_then(|attrs| attrs.condition.as_ref())
.is_none()
&& !node.is_parameterized(),
),
_ => Some(false),
}
}
Err(_) => None,
}
}

#[pymodule]
pub fn convert_2q_block_matrix(m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(blocks_to_matrix))?;
m.add_wrapped(wrap_pyfunction!(collect_2q_blocks_filter))?;
Ok(())
}
8 changes: 6 additions & 2 deletions crates/circuit/src/bit_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl PartialEq for BitAsKey {
impl Eq for BitAsKey {}

#[derive(Clone, Debug)]
pub(crate) struct BitData<T> {
pub struct BitData<T> {
/// The public field name (i.e. `qubits` or `clbits`).
description: String,
/// Registered Python bits.
Expand All @@ -81,7 +81,7 @@ pub(crate) struct BitData<T> {
cached: Py<PyList>,
}

pub(crate) struct BitNotFoundError<'py>(pub(crate) Bound<'py, PyAny>);
pub struct BitNotFoundError<'py>(pub(crate) Bound<'py, PyAny>);

impl<'py> From<BitNotFoundError<'py>> for PyErr {
fn from(error: BitNotFoundError) -> Self {
Expand Down Expand Up @@ -111,6 +111,10 @@ where
self.bits.len()
}

pub fn is_empty(&self) -> bool {
self.bits.is_empty()
}

/// Gets a reference to the underlying vector of Python bits.
#[inline]
pub fn bits(&self) -> &Vec<PyObject> {
Expand Down
1 change: 1 addition & 0 deletions crates/circuit/src/imports.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ pub static SINGLETON_CONTROLLED_GATE: ImportOnceCell =
pub static CONTROLLED_GATE: ImportOnceCell =
ImportOnceCell::new("qiskit.circuit", "ControlledGate");
pub static DEEPCOPY: ImportOnceCell = ImportOnceCell::new("copy", "deepcopy");
pub static QI_OPERATOR: ImportOnceCell = ImportOnceCell::new("qiskit.quantum_info", "Operator");
pub static WARNINGS_WARN: ImportOnceCell = ImportOnceCell::new("warnings", "warn");

/// A mapping from the enum variant in crate::operations::StandardGate to the python
Expand Down
2 changes: 1 addition & 1 deletion crates/circuit/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

pub mod bit_data;
pub mod circuit_data;
pub mod circuit_instruction;
pub mod dag_node;
Expand All @@ -20,7 +21,6 @@ pub mod parameter_table;
pub mod slice;
pub mod util;

mod bit_data;
mod interner;

use pyo3::prelude::*;
Expand Down
24 changes: 5 additions & 19 deletions qiskit/dagcircuit/dagcircuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from qiskit.circuit.bit import Bit
from qiskit.pulse import Schedule
from qiskit._accelerate.euler_one_qubit_decomposer import collect_1q_runs_filter
from qiskit._accelerate.convert_2q_block_matrix import collect_2q_blocks_filter

BitLocations = namedtuple("BitLocations", ("index", "registers"))
# The allowable arguments to :meth:`DAGCircuit.copy_empty_like`'s ``vars_mode``.
Expand Down Expand Up @@ -1348,9 +1349,9 @@ def replace_block_with_op(
for nd in node_block:
block_qargs |= set(nd.qargs)
block_cargs |= set(nd.cargs)
if (condition := getattr(nd.op, "condition", None)) is not None:
if (condition := getattr(nd, "condition", None)) is not None:
block_cargs.update(condition_resources(condition).clbits)
elif isinstance(nd.op, SwitchCaseOp):
elif nd.name in CONTROL_FLOW_OP_NAMES and isinstance(nd.op, SwitchCaseOp):
if isinstance(nd.op.target, Clbit):
block_cargs.add(nd.op.target)
elif isinstance(nd.op.target, ClassicalRegister):
Expand Down Expand Up @@ -2158,28 +2159,13 @@ def collect_1q_runs(self) -> list[list[DAGOpNode]]:
def collect_2q_runs(self):
"""Return a set of non-conditional runs of 2q "op" nodes."""

to_qid = {}
for i, qubit in enumerate(self.qubits):
to_qid[qubit] = i

def filter_fn(node):
if isinstance(node, DAGOpNode):
return (
isinstance(node.op, Gate)
and len(node.qargs) <= 2
and not getattr(node.op, "condition", None)
and not node.op.is_parameterized()
)
else:
return None

def color_fn(edge):
if isinstance(edge, Qubit):
return to_qid[edge]
return self.find_bit(edge).index
else:
return None

return rx.collect_bicolor_runs(self._multi_graph, filter_fn, color_fn)
return rx.collect_bicolor_runs(self._multi_graph, collect_2q_blocks_filter, color_fn)

def nodes_on_wire(self, wire, only_ops=False):
"""
Expand Down
19 changes: 11 additions & 8 deletions qiskit/transpiler/passes/optimization/consolidate_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
from qiskit.circuit.library.generalized_gates.unitary import UnitaryGate
from qiskit.circuit.library.standard_gates import CXGate
from qiskit.transpiler.basepasses import TransformationPass
from qiskit.circuit.controlflow import ControlFlowOp
from qiskit.transpiler.passmanager import PassManager
from qiskit.transpiler.passes.synthesis import unitary_synthesis
from qiskit.transpiler.passes.utils import _block_to_matrix
from qiskit.circuit.controlflow import CONTROL_FLOW_OP_NAMES
from qiskit._accelerate.convert_2q_block_matrix import blocks_to_matrix

from .collect_1q_runs import Collect1qRuns
from .collect_2q_blocks import Collect2qBlocks

Expand Down Expand Up @@ -105,14 +106,14 @@ def run(self, dag):
block_cargs = set()
for nd in block:
block_qargs |= set(nd.qargs)
if isinstance(nd, DAGOpNode) and getattr(nd.op, "condition", None):
block_cargs |= set(getattr(nd.op, "condition", None)[0])
if isinstance(nd, DAGOpNode) and getattr(nd, "condition", None):
block_cargs |= set(getattr(nd, "condition", None)[0])
all_block_gates.add(nd)
block_index_map = self._block_qargs_to_indices(dag, block_qargs)
for nd in block:
if nd.op.name == basis_gate_name:
if nd.name == basis_gate_name:
basis_count += 1
if self._check_not_in_basis(dag, nd.op.name, nd.qargs):
if self._check_not_in_basis(dag, nd.name, nd.qargs):
outside_basis = True
if len(block_qargs) > 2:
q = QuantumRegister(len(block_qargs))
Expand All @@ -124,7 +125,7 @@ def run(self, dag):
qc.append(nd.op, [q[block_index_map[i]] for i in nd.qargs])
unitary = UnitaryGate(Operator(qc), check_input=False)
else:
matrix = _block_to_matrix(block, block_index_map)
matrix = blocks_to_matrix(block, block_index_map)
unitary = UnitaryGate(matrix, check_input=False)

max_2q_depth = 20 # If depth > 20, there will be 1q gates to consolidate.
Expand Down Expand Up @@ -192,7 +193,9 @@ def _handle_control_flow_ops(self, dag):
pass_manager.append(Collect2qBlocks())

pass_manager.append(self)
for node in dag.op_nodes(ControlFlowOp):
for node in dag.op_nodes():
if node.name not in CONTROL_FLOW_OP_NAMES:
continue
node.op = node.op.replace_blocks(pass_manager.run(block) for block in node.op.blocks)
return dag

Expand Down
1 change: 0 additions & 1 deletion qiskit/transpiler/passes/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,3 @@

# Utility functions
from . import control_flow
from .block_to_matrix import _block_to_matrix
47 changes: 0 additions & 47 deletions qiskit/transpiler/passes/utils/block_to_matrix.py

This file was deleted.

0 comments on commit 4fe9dbc

Please sign in to comment.