Skip to content

Commit

Permalink
Fix: Keep track of Vars for add_from_iter
Browse files Browse the repository at this point in the history
- Make `from_iter` public.
  • Loading branch information
raynelfss committed Aug 21, 2024
1 parent edee939 commit 2600f5c
Showing 1 changed file with 41 additions and 6 deletions.
47 changes: 41 additions & 6 deletions crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6188,14 +6188,15 @@ impl DAGCircuit {
// Create HashSets to keep track of each bit/var's last node
let mut qubit_last_nodes: HashMap<Qubit, (NodeIndex, Wire)> = HashMap::default();
let mut clbit_last_nodes: HashMap<Clbit, (NodeIndex, Wire)> = HashMap::default();
// TODO: Keep track of vars
// TODO: Refactor once Vars are in rust
// Dict [ Var: (int, VarWeight)]
let vars_last_nodes: Bound<PyDict> = PyDict::new_bound(py);

// Store new nodes to return
let mut new_nodes = vec![];
for instr in iter {
let op_name = instr.op.name();
// TODO: Use _vars
let (all_cbits, _vars): (Vec<Clbit>, Option<Vec<PyObject>>) = {
let (all_cbits, vars): (Vec<Clbit>, Option<Vec<PyObject>>) = {
// Check if the clbits are already included
if self.may_have_additional_wires(py, &instr) {
let mut clbits: HashSet<Clbit> =
Expand Down Expand Up @@ -6266,29 +6267,63 @@ impl DAGCircuit {
nodes_to_connect.insert(clbit_last_node);
}

// TODO: Check all the vars in this instruction.
// If available, check all the vars in this instruction
if let Some(vars_available) = vars {
for var in vars_available {
let var_last_node = if vars_last_nodes.contains(&var)? {
let (node, wire): (usize, PyObject) =
vars_last_nodes.get_item(&var)?.unwrap().extract()?;
(NodeIndex::new(node), Wire::Var(wire))
} else {
let output_node = self.var_output_map.get(py, &var).unwrap();
let (edge_id, predecessor_node) = self
.dag
.edges_directed(output_node, Incoming)
.next()
.map(|edge| (edge.id(), (edge.source(), edge.weight().clone())))
.unwrap();
self.dag.remove_edge(edge_id);
predecessor_node
};

if let Wire::Var(var) = &var_last_node.1 {
vars_last_nodes.set_item(var, (new_node.index(), var))?
}
nodes_to_connect.insert(var_last_node);
}
}

// Add all of the new edges
for (node, wire) in nodes_to_connect {
self.dag.add_edge(node, new_node, wire);
}
}

// Add the output_nodes back
// Add the output_nodes back to qargs
for (qubit, (node, wire)) in qubit_last_nodes {
let output_node = self.qubit_io_map[&qubit][1];
self.dag.add_edge(node, output_node, wire);
}

// Add the output_nodes back to cargs
for (clbit, (node, wire)) in clbit_last_nodes {
let output_node = self.clbit_io_map[&clbit][1];
self.dag.add_edge(node, output_node, wire);
}

// Add the output_nodes back to vars
for item in vars_last_nodes.items() {
let (var, (node, wire)): (PyObject, (usize, PyObject)) = item.extract()?;
let output_node = self.var_output_map.get(py, &var).unwrap();
self.dag
.add_edge(NodeIndex::new(node), output_node, Wire::Var(wire));
}

Ok(new_nodes)
}

/// Creates an instance of DAGCircuit from an iterator over `PackedInstruction`.
fn from_iter<I>(_py: Python, _iter: I) -> PyResult<Self>
pub fn from_iter<I>(_py: Python, _iter: I) -> PyResult<Self>
where
I: IntoIterator<Item = PackedInstruction>,
{
Expand Down

0 comments on commit 2600f5c

Please sign in to comment.