Skip to content

Commit

Permalink
Fix: Incorrect pickle attribute extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
raynelfss committed Jun 28, 2024
1 parent ffa0a81 commit 9a7d9a0
Showing 1 changed file with 62 additions and 95 deletions.
157 changes: 62 additions & 95 deletions crates/circuit/src/equivalence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

use itertools::Itertools;

use pyo3::exceptions::PyTypeError;
use rustworkx_core::petgraph::csr::IndexType;
use rustworkx_core::petgraph::stable_graph::StableDiGraph;
use rustworkx_core::petgraph::visit::IntoEdgeReferences;
Expand Down Expand Up @@ -41,6 +42,7 @@ mod exceptions {
import_exception_bound! {qiskit.circuit.exceptions, CircuitError}
}
pub static PYDIGRAPH: ImportOnceCell = ImportOnceCell::new("rustworkx", "PyDiGraph");
pub static QUANTUMCIRCUIT: ImportOnceCell = ImportOnceCell::new("qiskit.circuit", "QuantumCircuit");

// Custom Structs

Expand Down Expand Up @@ -278,83 +280,60 @@ impl<'py> FromPyObject<'py> for GateOper {
#[derive(Debug, Clone)]
pub struct CircuitRep {
object: PyObject,
num_qubits: Option<u32>,
num_clbits: Option<u32>,
pub num_qubits: u32,
pub num_clbits: u32,
params: Option<SmallVec<[Param; 3]>>,
data: Option<Vec<CircuitInstruction>>,
// TODO: Have a valid implementation of CircuiData that's usable in Rust.
}

impl CircuitRep {
#[inline]
pub fn num_qubits(&mut self) -> u32 {
match &self.num_qubits {
Some(num_qubits) => *num_qubits,
None => {
let num_qubits = Python::with_gil(|py| -> PyResult<u32> {
self.object.getattr(py, "num_qubits")?.extract(py)
})
.unwrap_or_default();
self.num_qubits = Some(num_qubits);
num_qubits
}
}
}

#[inline]
pub fn num_clbits(&mut self) -> u32 {
match &self.num_clbits {
Some(num_clbits) => *num_clbits,
None => {
let num_clbits = Python::with_gil(|py| -> PyResult<u32> {
self.object.getattr(py, "num_clbits")?.extract(py)
})
.unwrap_or_default();
self.num_clbits = Some(num_clbits);
num_clbits
}
}
}

#[inline]
pub fn params(&mut self) -> &[Param] {
if self.params.is_some() {
pub fn parameters(&mut self) -> &[Param] {
if self.params.is_none() {
let params = Python::with_gil(|py| -> PyResult<SmallVec<[Param; 3]>> {
self.object
.bind(py)
.getattr("parameters")?
.getattr("data")?
.extract()
})
.unwrap_or_default();
self.params = Some(params);
return self.params.as_ref().unwrap();
}
let params = Python::with_gil(|py| -> PyResult<SmallVec<[Param; 3]>> {
self.object
.getattr(py, "params")?
.getattr(py, "data")?
.extract(py)
})
.unwrap_or_default();
self.params = Some(params);
self.params.as_ref().unwrap()
return self.params.as_ref().unwrap();
}

#[inline]
pub fn data(&mut self) -> &[CircuitInstruction] {
if self.data.is_some() {
if self.data.is_none() {
let data = Python::with_gil(|py| -> PyResult<Vec<CircuitInstruction>> {
self.object.bind(py).getattr("data")?.extract()
})
.unwrap_or_default();
self.data = Some(data);
return self.data.as_ref().unwrap();
}
let data = Python::with_gil(|py| -> PyResult<Vec<CircuitInstruction>> {
self.object.getattr(py, "data")?.extract(py)
})
.unwrap_or_default();
self.data = Some(data);
self.data.as_ref().unwrap()
return self.data.as_ref().unwrap();
}
}

impl FromPyObject<'_> for CircuitRep {
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
Ok(Self {
object: ob.to_object(ob.py()),
num_qubits: None,
num_clbits: None,
params: None,
data: None,
})
if ob.is_instance(QUANTUMCIRCUIT.get_bound(ob.py()))? {
let num_qubits = ob.getattr("num_qubits")?.extract()?;
let num_clbits = ob.getattr("num_clbits")?.extract()?;
Ok(Self {
object: ob.into_py(ob.py()),
num_qubits,
num_clbits,
params: None,
data: None,
})
} else {
Err(PyTypeError::new_err(
"Provided object was not an instance of QuantumCircuit",
))
}
}
}

Expand Down Expand Up @@ -391,8 +370,8 @@ impl Default for CircuitRep {
fn default() -> Self {
Self {
object: Python::with_gil(|py| py.None()),
num_qubits: None,
num_clbits: None,
num_qubits: 0,
num_clbits: 0,
params: None,
data: None,
}
Expand Down Expand Up @@ -491,8 +470,7 @@ impl EquivalenceLibrary {
/// entry (List['QuantumCircuit']) : A list of QuantumCircuits, each
/// equivalently implementing the given Gate.
fn set_entry(&mut self, gate: GateOper, entry: Vec<CircuitRep>) -> PyResult<()> {
let mut entry = entry;
match self.set_entry_native(&gate, &mut entry) {
match self.set_entry_native(gate, entry) {
Ok(_) => Ok(()),
Err(e) => Err(CircuitError::new_err(e.message)),
}
Expand Down Expand Up @@ -562,10 +540,10 @@ impl EquivalenceLibrary {
fn __getstate__(slf: PyRef<Self>) -> PyResult<Bound<'_, PyDict>> {
let ret = PyDict::new_bound(slf.py());
ret.set_item("rule_id", slf.rule_id)?;
let key_to_usize_node: HashMap<Key, usize> = HashMap::from_iter(
let key_to_usize_node: HashMap<(String, u32), usize> = HashMap::from_iter(
slf.key_to_node_index
.iter()
.map(|(key, val)| (key.clone(), val.index())),
.map(|(key, val)| ((key.name.to_string(), key.num_qubits), val.index())),
);
ret.set_item("key_to_node_index", key_to_usize_node.into_py(slf.py()))?;
let graph_nodes: Vec<NodeData> = slf._graph.node_weights().cloned().collect();
Expand All @@ -587,21 +565,6 @@ impl EquivalenceLibrary {

fn __setstate__(mut slf: PyRefMut<Self>, state: &Bound<'_, PyDict>) -> PyResult<()> {
slf.rule_id = state.get_item("rule_id")?.unwrap().extract()?;
state
.get_item("key_to_node_index")?
.unwrap()
.downcast::<PyDict>()?
.items()
.iter()
.filter_map(
|item| match (item.extract::<Key>().ok(), item.extract::<usize>().ok()) {
(Some(key), Some(value)) => Some((key, value)),
_ => None,
},
)
.for_each(|(key, value)| {
slf.key_to_node_index.insert(key, NodeIndex::new(value));
});
let graph_nodes: Vec<NodeData> = state.get_item("graph_nodes")?.unwrap().extract()?;
let graph_edges: Vec<(usize, usize, EdgeData)> =
state.get_item("graph_edges")?.unwrap().extract()?;
Expand All @@ -616,6 +579,13 @@ impl EquivalenceLibrary {
edge_weight,
);
}
slf.key_to_node_index = state
.get_item("key_to_node_index")?
.unwrap()
.extract::<HashMap<(String, u32), usize>>()?
.into_iter()
.map(|((name, num_qubits), val)| (Key::new(name, num_qubits), NodeIndex::new(val)))
.collect();
slf.graph = None;
Ok(())
}
Expand Down Expand Up @@ -656,8 +626,8 @@ impl EquivalenceLibrary {
gate: GateOper,
mut equivalent_circuit: CircuitRep,
) -> Result<(), EquivalenceError> {
raise_if_shape_mismatch(&gate, &mut equivalent_circuit)?;
raise_if_param_mismatch(&gate.params, equivalent_circuit.params())?;
raise_if_shape_mismatch(&gate, &equivalent_circuit)?;
raise_if_param_mismatch(&gate.params, equivalent_circuit.parameters())?;

let key: Key = Key {
name: gate.operation.name().to_string(),
Expand Down Expand Up @@ -711,12 +681,12 @@ impl EquivalenceLibrary {
/// equivalently implementing the given Gate.
pub fn set_entry_native(
&mut self,
gate: &GateOper,
entry: &mut Vec<CircuitRep>,
gate: GateOper,
mut entry: Vec<CircuitRep>,
) -> Result<(), EquivalenceError> {
for equiv in &mut *entry {
raise_if_shape_mismatch(gate, equiv)?;
raise_if_param_mismatch(&gate.params, equiv.params())?;
for equiv in entry.iter_mut() {
raise_if_shape_mismatch(&gate, equiv)?;
raise_if_param_mismatch(&gate.params, equiv.parameters())?;
}

let key = Key {
Expand Down Expand Up @@ -766,21 +736,18 @@ fn raise_if_param_mismatch(
Ok(())
}

fn raise_if_shape_mismatch(
gate: &GateOper,
circuit: &mut CircuitRep,
) -> Result<(), EquivalenceError> {
if gate.operation.num_qubits() != circuit.num_qubits()
|| gate.operation.num_clbits() != circuit.num_clbits()
fn raise_if_shape_mismatch(gate: &GateOper, circuit: &CircuitRep) -> Result<(), EquivalenceError> {
if gate.operation.num_qubits() != circuit.num_qubits
|| gate.operation.num_clbits() != circuit.num_clbits
{
return Err(EquivalenceError::new_err(format!(
"Cannot add equivalence between circuit and gate \
of different shapes. Gate: {} qubits and {} clbits. \
Circuit: {} qubits and {} clbits.",
gate.operation.num_qubits(),
gate.operation.num_clbits(),
circuit.num_qubits(),
circuit.num_clbits()
circuit.num_qubits,
circuit.num_clbits
)));
}
Ok(())
Expand Down

0 comments on commit 9a7d9a0

Please sign in to comment.