Skip to content

Commit

Permalink
Merge branch 'main' into rusty-dag-add-from-iter
Browse files Browse the repository at this point in the history
  • Loading branch information
raynelfss authored Sep 5, 2024
2 parents 2ee5b40 + 86c63eb commit 2b765b7
Show file tree
Hide file tree
Showing 39 changed files with 280 additions and 80 deletions.
192 changes: 192 additions & 0 deletions crates/accelerate/src/commutation_analysis.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
// This code is part of Qiskit.
//
// (C) Copyright IBM 2024
//
// This code is licensed under the Apache License, Version 2.0. You may
// obtain a copy of this license in the LICENSE.txt file in the root directory
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
//
// Any modifications or derivative works of this code must retain this
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

use pyo3::exceptions::PyValueError;
use pyo3::prelude::PyModule;
use pyo3::{pyfunction, pymodule, wrap_pyfunction, Bound, PyResult, Python};
use qiskit_circuit::Qubit;

use crate::commutation_checker::CommutationChecker;
use hashbrown::HashMap;
use pyo3::prelude::*;

use pyo3::types::{PyDict, PyList};
use qiskit_circuit::dag_circuit::{DAGCircuit, NodeType, Wire};
use rustworkx_core::petgraph::stable_graph::NodeIndex;

// Custom types to store the commutation sets and node indices,
// see the docstring below for more information.
type CommutationSet = HashMap<Wire, Vec<Vec<NodeIndex>>>;
type NodeIndices = HashMap<(NodeIndex, Wire), usize>;

// the maximum number of qubits we check commutativity for
const MAX_NUM_QUBITS: u32 = 3;

/// Compute the commutation sets for a given DAG.
///
/// We return two HashMaps:
/// * {wire: commutation_sets}: For each wire, we keep a vector of index sets, where each index
/// set contains mutually commuting nodes. Note that these include the input and output nodes
/// which do not commute with anything.
/// * {(node, wire): index}: For each (node, wire) pair we store the index indicating in which
/// commutation set the node appears on a given wire.
///
/// For example, if we have a circuit
///
/// |0> -- X -- SX -- Z (out)
/// 0 2 3 4 1 <-- node indices including input (0) and output (1) nodes
///
/// Then we would have
///
/// commutation_set = {0: [[0], [2, 3], [4], [1]]}
/// node_indices = {(0, 0): 0, (1, 0): 3, (2, 0): 1, (3, 0): 1, (4, 0): 2}
///
fn analyze_commutations_inner(
py: Python,
dag: &mut DAGCircuit,
commutation_checker: &mut CommutationChecker,
) -> PyResult<(CommutationSet, NodeIndices)> {
let mut commutation_set: CommutationSet = HashMap::new();
let mut node_indices: NodeIndices = HashMap::new();

for qubit in 0..dag.num_qubits() {
let wire = Wire::Qubit(Qubit(qubit as u32));

for current_gate_idx in dag.nodes_on_wire(py, &wire, false) {
// get the commutation set associated with the current wire, or create a new
// index set containing the current gate
let commutation_entry = commutation_set
.entry(wire.clone())
.or_insert_with(|| vec![vec![current_gate_idx]]);

// we can unwrap as we know the commutation entry has at least one element
let last = commutation_entry.last_mut().unwrap();

// if the current gate index is not in the set, check whether it commutes with
// the previous nodes -- if yes, add it to the commutation set
if !last.contains(&current_gate_idx) {
let mut all_commute = true;

for prev_gate_idx in last.iter() {
// if the node is an input/output node, they do not commute, so we only
// continue if the nodes are operation nodes
if let (NodeType::Operation(packed_inst0), NodeType::Operation(packed_inst1)) =
(&dag.dag[current_gate_idx], &dag.dag[*prev_gate_idx])
{
let op1 = packed_inst0.op.view();
let op2 = packed_inst1.op.view();
let params1 = packed_inst0.params_view();
let params2 = packed_inst1.params_view();
let qargs1 = dag.get_qargs(packed_inst0.qubits);
let qargs2 = dag.get_qargs(packed_inst1.qubits);
let cargs1 = dag.get_cargs(packed_inst0.clbits);
let cargs2 = dag.get_cargs(packed_inst1.clbits);

all_commute = commutation_checker.commute_inner(
py,
&op1,
params1,
packed_inst0.extra_attrs.as_deref(),
qargs1,
cargs1,
&op2,
params2,
packed_inst1.extra_attrs.as_deref(),
qargs2,
cargs2,
MAX_NUM_QUBITS,
)?;
if !all_commute {
break;
}
} else {
all_commute = false;
break;
}
}

if all_commute {
// all commute, add to current list
last.push(current_gate_idx);
} else {
// does not commute, create new list
commutation_entry.push(vec![current_gate_idx]);
}
}

node_indices.insert(
(current_gate_idx, wire.clone()),
commutation_entry.len() - 1,
);
}
}

Ok((commutation_set, node_indices))
}

#[pyfunction]
#[pyo3(signature = (dag, commutation_checker))]
pub(crate) fn analyze_commutations(
py: Python,
dag: &mut DAGCircuit,
commutation_checker: &mut CommutationChecker,
) -> PyResult<Py<PyDict>> {
// This returns two HashMaps:
// * The commuting nodes per wire: {wire: [commuting_nodes_1, commuting_nodes_2, ...]}
// * The index in which commutation set a given node is located on a wire: {(node, wire): index}
// The Python dict will store both of these dictionaries in one.
let (commutation_set, node_indices) = analyze_commutations_inner(py, dag, commutation_checker)?;

let out_dict = PyDict::new_bound(py);

// First set the {wire: [commuting_nodes_1, ...]} bit
for (wire, commutations) in commutation_set {
// we know all wires are of type Wire::Qubit, since in analyze_commutations_inner
// we only iterater over the qubits
let py_wire = match wire {
Wire::Qubit(q) => dag.qubits.get(q).unwrap().to_object(py),
_ => return Err(PyValueError::new_err("Unexpected wire type.")),
};

out_dict.set_item(
py_wire,
PyList::new_bound(
py,
commutations.iter().map(|inner| {
PyList::new_bound(
py,
inner
.iter()
.map(|node_index| dag.get_node(py, *node_index).unwrap()),
)
}),
),
)?;
}

// Then we add the {(node, wire): index} dictionary
for ((node_index, wire), index) in node_indices {
let py_wire = match wire {
Wire::Qubit(q) => dag.qubits.get(q).unwrap().to_object(py),
_ => return Err(PyValueError::new_err("Unexpected wire type.")),
};
out_dict.set_item((dag.get_node(py, node_index)?, py_wire), index)?;
}

Ok(out_dict.unbind())
}

#[pymodule]
pub fn commutation_analysis(m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(analyze_commutations))?;
Ok(())
}
4 changes: 2 additions & 2 deletions crates/accelerate/src/commutation_checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ where
/// lookups. It's not meant to be a public facing Python object though and only used
/// internally by the Python class.
#[pyclass(module = "qiskit._accelerate.commutation_checker")]
struct CommutationChecker {
pub struct CommutationChecker {
library: CommutationLibrary,
cache_max_entries: usize,
cache: HashMap<(String, String), CommutationCacheEntry>,
Expand Down Expand Up @@ -227,7 +227,7 @@ impl CommutationChecker {

impl CommutationChecker {
#[allow(clippy::too_many_arguments)]
fn commute_inner(
pub fn commute_inner(
&mut self,
py: Python,
op1: &OperationRef,
Expand Down
1 change: 1 addition & 0 deletions crates/accelerate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use std::env;
use pyo3::import_exception;

pub mod circuit_library;
pub mod commutation_analysis;
pub mod commutation_checker;
pub mod convert_2q_block_matrix;
pub mod dense_layout;
Expand Down
2 changes: 1 addition & 1 deletion crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5528,7 +5528,7 @@ impl DAGCircuit {
/// Get the nodes on the given wire.
///
/// Note: result is empty if the wire is not in the DAG.
fn nodes_on_wire(&self, py: Python, wire: &Wire, only_ops: bool) -> Vec<NodeIndex> {
pub fn nodes_on_wire(&self, py: Python, wire: &Wire, only_ops: bool) -> Vec<NodeIndex> {
let mut nodes = Vec::new();
let mut current_node = match wire {
Wire::Qubit(qubit) => self.qubit_io_map.get(qubit.0 as usize).map(|x| x[0]),
Expand Down
33 changes: 26 additions & 7 deletions crates/circuit/src/interner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,7 @@ where
/// `Interner::get_default` to reliably work correctly without a hash lookup (though ideally
/// we'd just use specialisation to do that).
pub fn new() -> Self {
let mut set = IndexSet::with_capacity_and_hasher(1, Default::default());
set.insert(Default::default());
Self(set)
Self::with_capacity(1)
}

/// Retrieve the key corresponding to the default store, without any hash or equality lookup.
Expand All @@ -126,11 +124,14 @@ where
}
}

/// Create an interner with enough space to hold `capacity` entries.
///
/// Note that the default item of the interner is always allocated and given a key immediately,
/// which will use one slot of the capacity.
pub fn with_capacity(capacity: usize) -> Self {
Self(IndexSet::with_capacity_and_hasher(
capacity,
::ahash::RandomState::new(),
))
let mut set = IndexSet::with_capacity_and_hasher(capacity, ::ahash::RandomState::new());
set.insert(Default::default());
Self(set)
}
}

Expand Down Expand Up @@ -196,3 +197,21 @@ where
}
}
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn default_key_exists() {
let mut interner = Interner::<[u32]>::new();
assert_eq!(interner.get_default(), interner.get_default());
assert_eq!(interner.get(interner.get_default()), &[]);
assert_eq!(interner.insert_owned(Vec::new()), interner.get_default());
assert_eq!(interner.insert(&[]), interner.get_default());

let capacity = Interner::<str>::with_capacity(4);
assert_eq!(capacity.get_default(), capacity.get_default());
assert_eq!(capacity.get(capacity.get_default()), "");
}
}
19 changes: 10 additions & 9 deletions crates/pyext/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
use pyo3::prelude::*;

use qiskit_accelerate::{
circuit_library::circuit_library, commutation_checker::commutation_checker,
convert_2q_block_matrix::convert_2q_block_matrix, dense_layout::dense_layout,
error_map::error_map, euler_one_qubit_decomposer::euler_one_qubit_decomposer,
isometry::isometry, nlayout::nlayout, optimize_1q_gates::optimize_1q_gates,
pauli_exp_val::pauli_expval, results::results, sabre::sabre, sampled_exp_val::sampled_exp_val,
sparse_pauli_op::sparse_pauli_op, star_prerouting::star_prerouting,
stochastic_swap::stochastic_swap, synthesis::synthesis, target_transpiler::target,
two_qubit_decompose::two_qubit_decompose, uc_gate::uc_gate, utils::utils,
vf2_layout::vf2_layout,
circuit_library::circuit_library, commutation_analysis::commutation_analysis,
commutation_checker::commutation_checker, convert_2q_block_matrix::convert_2q_block_matrix,
dense_layout::dense_layout, error_map::error_map,
euler_one_qubit_decomposer::euler_one_qubit_decomposer, isometry::isometry, nlayout::nlayout,
optimize_1q_gates::optimize_1q_gates, pauli_exp_val::pauli_expval, results::results,
sabre::sabre, sampled_exp_val::sampled_exp_val, sparse_pauli_op::sparse_pauli_op,
star_prerouting::star_prerouting, stochastic_swap::stochastic_swap, synthesis::synthesis,
target_transpiler::target, two_qubit_decompose::two_qubit_decompose, uc_gate::uc_gate,
utils::utils, vf2_layout::vf2_layout,
};

#[inline(always)]
Expand Down Expand Up @@ -62,5 +62,6 @@ fn _accelerate(m: &Bound<PyModule>) -> PyResult<()> {
add_submodule(m, utils, "utils")?;
add_submodule(m, vf2_layout, "vf2_layout")?;
add_submodule(m, commutation_checker, "commutation_checker")?;
add_submodule(m, commutation_analysis, "commutation_analysis")?;
Ok(())
}
1 change: 1 addition & 0 deletions qiskit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
sys.modules["qiskit._accelerate.synthesis.linear"] = _accelerate.synthesis.linear
sys.modules["qiskit._accelerate.synthesis.clifford"] = _accelerate.synthesis.clifford
sys.modules["qiskit._accelerate.commutation_checker"] = _accelerate.commutation_checker
sys.modules["qiskit._accelerate.commutation_analysis"] = _accelerate.commutation_analysis
sys.modules["qiskit._accelerate.synthesis.linear_phase"] = _accelerate.synthesis.linear_phase

from qiskit.exceptions import QiskitError, MissingOptionalLibraryError
Expand Down
52 changes: 3 additions & 49 deletions qiskit/transpiler/passes/optimization/commutation_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@

"""Analysis pass to find commutation relations between DAG nodes."""

from collections import defaultdict

from qiskit.circuit.commutation_library import SessionCommutationChecker as scc
from qiskit.dagcircuit import DAGOpNode
from qiskit.transpiler.basepasses import AnalysisPass
from qiskit._accelerate.commutation_analysis import analyze_commutations


class CommutationAnalysis(AnalysisPass):
Expand All @@ -33,6 +31,7 @@ def __init__(self, *, _commutation_checker=None):
# do not care about commutations of all gates, but just a subset
if _commutation_checker is None:
_commutation_checker = scc

self.comm_checker = _commutation_checker

def run(self, dag):
Expand All @@ -42,49 +41,4 @@ def run(self, dag):
into the ``property_set``.
"""
# Initiate the commutation set
self.property_set["commutation_set"] = defaultdict(list)

# Build a dictionary to keep track of the gates on each qubit
# The key with format (wire) will store the lists of commutation sets
# The key with format (node, wire) will store the index of the commutation set
# on the specified wire, thus, for example:
# self.property_set['commutation_set'][wire][(node, wire)] will give the
# commutation set that contains node.

for wire in dag.qubits:
self.property_set["commutation_set"][wire] = []

# Add edges to the dictionary for each qubit
for node in dag.topological_op_nodes():
for _, _, edge_wire in dag.edges(node):
self.property_set["commutation_set"][(node, edge_wire)] = -1

# Construct the commutation set
for wire in dag.qubits:

for current_gate in dag.nodes_on_wire(wire):

current_comm_set = self.property_set["commutation_set"][wire]
if not current_comm_set:
current_comm_set.append([current_gate])

if current_gate not in current_comm_set[-1]:
does_commute = True

# Check if the current gate commutes with all the gates in the current block
for prev_gate in current_comm_set[-1]:
does_commute = (
isinstance(current_gate, DAGOpNode)
and isinstance(prev_gate, DAGOpNode)
and self.comm_checker.commute_nodes(current_gate, prev_gate)
)
if not does_commute:
break

if does_commute:
current_comm_set[-1].append(current_gate)
else:
current_comm_set.append([current_gate])

temp_len = len(current_comm_set)
self.property_set["commutation_set"][(current_gate, wire)] = temp_len - 1
self.property_set["commutation_set"] = analyze_commutations(dag, self.comm_checker.cc)
Loading

0 comments on commit 2b765b7

Please sign in to comment.