Skip to content

Commit

Permalink
Fix: Review comments and ownership issues.
Browse files Browse the repository at this point in the history
- Add `from_operation` constructor for `Key`.
- Made `py_has_entry()` private, but kept its main method public.
- Made `set_entry` more rust friendly.
- Modify `add_equivalence` to accept a slice of `Param` and use `Into` to convert it into a `SmallVec` instance.
  • Loading branch information
raynelfss committed Sep 20, 2024
1 parent 1a21204 commit b7c6316
Showing 1 changed file with 76 additions and 62 deletions.
138 changes: 76 additions & 62 deletions crates/accelerate/src/equivalence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,15 @@ impl Key {
self.gt(other)
}
}
impl Key {
fn from_operation(operation: &PackedOperation) -> Self {
let op_ref: OperationRef = operation.view();
Key {
name: op_ref.name().to_string(),
num_qubits: op_ref.num_qubits(),
}
}
}

impl Display for Key {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Expand Down Expand Up @@ -268,6 +277,8 @@ impl Display for EdgeData {
}

/// Enum that helps extract the Operation and Parameters on a Gate.
/// It is highly derivative of `PackedOperation` while also tracking the specific
/// parameter objects.
#[derive(Debug, Clone)]
pub struct GateOper {
operation: PackedOperation,
Expand All @@ -284,8 +295,11 @@ impl<'py> FromPyObject<'py> for GateOper {
}
}

/// Representation of QuantumCircuit which the original circuit object + an
/// instance of `CircuitData`.
/// Representation of QuantumCircuit by using an instance of `CircuitData`.]
///
/// TODO: Remove this implementation once the `EquivalenceLibrary` is no longer
/// called from Python, or once the API is able to seamlessly accept instances
/// of `CircuitData`.
#[derive(Debug, Clone)]
pub struct CircuitRep(pub CircuitData);

Expand Down Expand Up @@ -383,7 +397,7 @@ impl EquivalenceLibrary {
gate: GateOper,
equivalent_circuit: CircuitRep,
) -> PyResult<()> {
self.add_equivalence(py, &gate, equivalent_circuit)
self.add_equivalence(py, &gate.operation, &gate.params, equivalent_circuit)
}

/// Check if a library contains any decompositions for gate.
Expand All @@ -395,7 +409,7 @@ impl EquivalenceLibrary {
/// Bool: True if gate has a known decomposition in the library.
/// False otherwise.
#[pyo3(name = "has_entry")]
pub fn py_has_entry(&self, gate: GateOper) -> bool {
fn py_has_entry(&self, gate: GateOper) -> bool {
self.has_entry(&gate.operation)
}

Expand All @@ -410,35 +424,9 @@ impl EquivalenceLibrary {
/// gate (Gate): A Gate instance.
/// entry (List['QuantumCircuit']) : A list of QuantumCircuits, each
/// equivalently implementing the given Gate.
fn set_entry(&mut self, py: Python, gate: GateOper, entry: Vec<CircuitRep>) -> PyResult<()> {
for equiv in entry.iter() {
raise_if_shape_mismatch(&gate, equiv)?;
raise_if_param_mismatch(py, &gate.params, equiv.0.unsorted_parameters(py)?)?;
}
let op_ref: OperationRef = gate.operation.view();
let key = Key {
name: op_ref.name().to_string(),
num_qubits: op_ref.num_qubits(),
};
let node_index = self.set_default_node(key);

if let Some(graph_ind) = self.graph.node_weight_mut(node_index) {
graph_ind.equivs.clear();
}

let edges: Vec<EdgeIndex> = self
.graph
.edges_directed(node_index, rustworkx_core::petgraph::Direction::Incoming)
.map(|x| x.id())
.collect();
for edge in edges {
self.graph.remove_edge(edge);
}
for equiv in entry {
self.add_equivalence(py, &gate, equiv)?
}
self._graph = None;
Ok(())
#[pyo3(name = "set_entry")]
fn py_set_entry(&mut self, py: Python, gate: GateOper, entry: Vec<CircuitRep>) -> PyResult<()> {
self.set_entry(py, &gate.operation, &gate.params, entry)
}

/// Gets the set of QuantumCircuits circuits from the library which
Expand All @@ -459,11 +447,7 @@ impl EquivalenceLibrary {
/// ordering of the StandardEquivalenceLibrary will not generally be
/// consistent across Qiskit versions.
fn get_entry(&self, py: Python, gate: GateOper) -> PyResult<Py<PyList>> {
let op_ref = gate.operation.view();
let key = Key {
name: op_ref.name().to_string(),
num_qubits: op_ref.num_qubits(),
};
let key = Key::from_operation(&gate.operation);
let query_params = gate.params;

let bound_equivalencies = self
Expand All @@ -477,6 +461,7 @@ impl EquivalenceLibrary {
Ok(return_list.unbind())
}

// TODO: Remove once BasisTranslator is in Rust.
#[getter]
fn get_graph(&mut self, py: Python) -> PyResult<PyObject> {
if let Some(graph) = &self._graph {
Expand Down Expand Up @@ -574,37 +559,34 @@ impl EquivalenceLibrary {

// Rust native methods
impl EquivalenceLibrary {
fn add_equivalence(
/// Add a new equivalence to the library. Future queries for the Gate
/// will include the given circuit, in addition to all existing equivalences
/// (including those from base).
pub fn add_equivalence(
&mut self,
py: Python,
gate: &GateOper,
gate: &PackedOperation,
params: &[Param],
equivalent_circuit: CircuitRep,
) -> PyResult<()> {
raise_if_shape_mismatch(gate, &equivalent_circuit)?;
raise_if_param_mismatch(
py,
&gate.params,
equivalent_circuit.0.unsorted_parameters(py)?,
)?;
let op_ref = gate.operation.view();
let key: Key = Key {
name: op_ref.name().to_string(),
num_qubits: op_ref.num_qubits(),
};
raise_if_param_mismatch(py, params, equivalent_circuit.0.unsorted_parameters(py)?)?;
let key: Key = Key::from_operation(gate);
let equiv = Equivalence {
params: gate.params.clone(),
circuit: equivalent_circuit.clone(),
params: params.into(),
};

let target = self.set_default_node(key);
if let Some(node) = self.graph.node_weight_mut(target) {
node.equivs.push(equiv.clone());
}
let sources: IndexSet<Key, RandomState> =
IndexSet::from_iter(equivalent_circuit.0.iter().map(|inst| Key {
name: inst.op.view().name().to_string(),
num_qubits: inst.op.view().num_qubits(),
}));
let sources: IndexSet<Key, RandomState> = IndexSet::from_iter(
equivalent_circuit
.0
.iter()
.map(|inst| Key::from_operation(&inst.op)),
);
let edges = Vec::from_iter(sources.iter().map(|source| {
(
self.set_default_node(source.clone()),
Expand All @@ -625,6 +607,41 @@ impl EquivalenceLibrary {
Ok(())
}

/// Set the equivalence record for a Gate. Future queries for the Gate
/// will return only the circuits provided.
pub fn set_entry(
&mut self,
py: Python,
gate: &PackedOperation,
params: &[Param],
entry: Vec<CircuitRep>,
) -> PyResult<()> {
for equiv in entry.iter() {
raise_if_shape_mismatch(gate, equiv)?;
raise_if_param_mismatch(py, params, equiv.0.unsorted_parameters(py)?)?;
}
let key = Key::from_operation(gate);
let node_index = self.set_default_node(key);

if let Some(graph_ind) = self.graph.node_weight_mut(node_index) {
graph_ind.equivs.clear();
}

let edges: Vec<EdgeIndex> = self
.graph
.edges_directed(node_index, rustworkx_core::petgraph::Direction::Incoming)
.map(|x| x.id())
.collect();
for edge in edges {
self.graph.remove_edge(edge);
}
for equiv in entry {
self.add_equivalence(py, gate, params, equiv)?
}
self._graph = None;
Ok(())
}

/// Rust native equivalent to `EquivalenceLibrary.has_entry()`
///
/// Check if a library contains any decompositions for gate.
Expand All @@ -636,10 +653,7 @@ impl EquivalenceLibrary {
/// `bool`: `true` if gate has a known decomposition in the library.
/// `false` otherwise.
pub fn has_entry(&self, operation: &PackedOperation) -> bool {
let key = Key {
name: operation.view().name().to_string(),
num_qubits: operation.view().num_qubits(),
};
let key = Key::from_operation(operation);
self.key_to_node_index.contains_key(&key)
}

Expand Down Expand Up @@ -695,8 +709,8 @@ fn raise_if_param_mismatch(
Ok(())
}

fn raise_if_shape_mismatch(gate: &GateOper, circuit: &CircuitRep) -> PyResult<()> {
let op_ref = gate.operation.view();
fn raise_if_shape_mismatch(gate: &PackedOperation, circuit: &CircuitRep) -> PyResult<()> {
let op_ref = gate.view();
if op_ref.num_qubits() != circuit.0.num_qubits() as u32
|| op_ref.num_clbits() != circuit.0.num_clbits() as u32
{
Expand Down

0 comments on commit b7c6316

Please sign in to comment.