Skip to content

Commit

Permalink
Fix: Use owned Strings.
Browse files Browse the repository at this point in the history
- Due to the nature of `hashbrown` we must use owned Strings instead of `&str`.
  • Loading branch information
raynelfss committed Sep 5, 2024
1 parent 778e99d commit 6e2f011
Showing 1 changed file with 33 additions and 25 deletions.
58 changes: 33 additions & 25 deletions crates/accelerate/src/basis/basis_translator/basis_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ use std::cell::RefCell;

use hashbrown::{HashMap, HashSet};
use pyo3::prelude::*;

use crate::equivalence::{CircuitRep, EdgeData, Equivalence, EquivalenceLibrary, Key, NodeData};
use qiskit_circuit::operations::{Operation, Param};
use rustworkx_core::petgraph::stable_graph::{EdgeReference, NodeIndex, StableDiGraph};
use rustworkx_core::petgraph::visit::Control;
use rustworkx_core::traversal::{dijkstra_search, DijkstraEvent};
use smallvec::SmallVec;

use crate::equivalence::{CircuitRep, EdgeData, Equivalence, EquivalenceLibrary, Key, NodeData};

#[pyfunction]
#[pyo3(name = "basis_search")]
/// Search for a set of transformations from source_basis to target_basis.
Expand Down Expand Up @@ -68,8 +68,9 @@ pub(crate) fn basis_search(
) -> Option<BasisTransforms> {
// Build the visitor attributes:
let mut num_gates_remaining_for_rule: HashMap<usize, usize> = HashMap::default();
let predecessors: RefCell<HashMap<(&str, u32), Equivalence>> = RefCell::new(HashMap::default());
let opt_cost_map: RefCell<HashMap<(&str, u32), u32>> = RefCell::new(HashMap::default());
let predecessors: RefCell<HashMap<(String, u32), Equivalence>> =
RefCell::new(HashMap::default());
let opt_cost_map: RefCell<HashMap<(String, u32), u32>> = RefCell::new(HashMap::default());
let mut basis_transforms: Vec<(String, u32, SmallVec<[Param; 3]>, CircuitRep)> = vec![];

// Initialize visitor attributes:
Expand Down Expand Up @@ -124,20 +125,26 @@ pub(crate) fn basis_search(
});

// Edge cost function for Visitor
let edge_weight =
|edge: EdgeReference<Option<EdgeData>>| -> Result<u32, ()> {
if edge.weight().is_none() {
return Ok(1);
}
let edge_data = edge.weight().as_ref().unwrap();
let mut cost_tot = 0;
let borrowed_cost = opt_cost_map_cell.borrow();
for instruction in edge_data.rule.circuit.0.iter() {
cost_tot += borrowed_cost[&(instruction.op.name(), instruction.op.num_qubits())];
}
Ok(cost_tot
- borrowed_cost[&(edge_data.source.name.as_str(), edge_data.source.num_qubits)])
};
let edge_weight = |edge: EdgeReference<Option<EdgeData>>| -> Result<u32, ()> {
if edge.weight().is_none() {
return Ok(1);
}
let edge_data = edge.weight().as_ref().unwrap();
let mut cost_tot = 0;
let borrowed_cost = opt_cost_map.borrow();
for instruction in edge_data.rule.circuit.0.iter() {
let instruction_op = instruction.op.view();
cost_tot += borrowed_cost[&(
instruction_op.name().to_string(),
instruction_op.num_qubits(),
)];
}
Ok(cost_tot
- borrowed_cost[&(
edge_data.source.name.to_string(),
edge_data.source.num_qubits,
)])
};

let basis_transforms = match dijkstra_search(
&equiv_lib.graph,
Expand All @@ -147,14 +154,15 @@ pub(crate) fn basis_search(
match event {
DijkstraEvent::Discover(n, score) => {
let gate_key = &equiv_lib.graph[n].key;
let gate = &(gate_key.name.as_str(), gate_key.num_qubits);
let gate = (gate_key.name.to_string(), gate_key.num_qubits);
source_basis_remain.remove(gate_key);
let mut borrowed_cost_map = opt_cost_map.borrow_mut();
borrowed_cost_map
.entry(*gate)
.and_modify(|cost_ref| *cost_ref = score)
.or_insert(score);
if let Some(rule) = predecessors.borrow().get(gate) {
if let Some(entry) = borrowed_cost_map.get_mut(&gate) {
*entry = score;
} else {
borrowed_cost_map.insert(gate.clone(), score);
}
if let Some(rule) = predecessors.borrow().get(&gate) {
// TODO: Logger
basis_transforms.push((
gate_key.name.to_string(),
Expand All @@ -173,7 +181,7 @@ pub(crate) fn basis_search(
let gate = &equiv_lib.graph[target].key;
predecessors
.borrow_mut()
.entry((gate.name.as_str(), gate.num_qubits))
.entry((gate.name.to_string(), gate.num_qubits))
.and_modify(|value| *value = edata.rule.clone())
.or_insert(edata.rule.clone());
}
Expand Down

0 comments on commit 6e2f011

Please sign in to comment.