diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index f3e763feb..cce6ea97d 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -2,10 +2,11 @@ //! Dataflow analysis of Hugrs. mod datalog; +pub use datalog::Machine; mod value_row; mod machine; -pub use machine::{AnalysisResults, Machine, TailLoopTermination}; +pub use machine::{AnalysisResults, TailLoopTermination}; mod partial_value; pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 143638ae2..0421567d0 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -6,10 +6,10 @@ use std::hash::Hash; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use hugr_core::ops::{OpTrait, OpType}; -use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _}; +use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; -use super::{AbstractValue, DFContext, PartialValue}; +use super::{AbstractValue, AnalysisResults, DFContext, PartialValue}; type PV = PartialValue; @@ -19,18 +19,53 @@ pub enum IO { Output, } -pub(super) struct DatalogResults { - pub in_wire_value: Vec<(Node, IncomingPort, PV)>, - pub out_wire_value: Vec<(Node, OutgoingPort, PV)>, - pub case_reachable: Vec<(Node, Node)>, - pub bb_reachable: Vec<(Node, Node)>, +/// Basic structure for performing an analysis. Usage: +/// 1. Get a new instance via [Self::default()] +/// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values +/// 3. Call [Self::run] to produce [AnalysisResults] which can be inspected via +/// [read_out_wire](AnalysisResults::read_out_wire) +pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); + +/// derived-Default requires the context to be Defaultable, which is unnecessary +impl Default for Machine { + fn default() -> Self { + Self(Default::default()) + } } -pub(super) fn run_datalog( +impl Machine { + /// Provide initial values for some wires. + // Likely for test purposes only - should we make non-pub or #[cfg(test)] ? + pub fn prepopulate_wire(&mut self, h: &impl HugrView, wire: Wire, value: PartialValue) { + self.0.extend( + h.linked_inputs(wire.node(), wire.source()) + .map(|(n, inp)| (n, inp, value.clone())), + ); + } + + /// Run the analysis (iterate until a lattice fixpoint is reached), + /// given initial values for some of the root node inputs. + /// (Note that `in_values` will not be useful for `Case` or `DFB`-rooted Hugrs, + /// but should handle other containers.) + /// The context passed in allows interpretation of leaf operations. + pub fn run( + mut self, + context: &impl DFContext, + hugr: H, + in_values: impl IntoIterator)>, + ) -> AnalysisResults { + let root = hugr.root(); + self.0 + .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); + run_datalog(self.0, context, hugr) + } +} + +pub(super) fn run_datalog( in_wire_value_proto: Vec<(Node, IncomingPort, PV)>, c: &impl DFContext, - hugr: &impl HugrView, -) -> DatalogResults { + hugr: H, +) -> AnalysisResults { // ascent-(macro-)generated code generates a bunch of warnings, // keep code in here to a minimum. #![allow( @@ -84,7 +119,7 @@ pub(super) fn run_datalog( if !op_t.is_container(), if let Some(sig) = op_t.dataflow_signature(), node_in_value_row(n, vs), - if let Some(outs) = propagate_leaf_op(c, hugr, *n, &vs[..], sig.output_count()), + if let Some(outs) = propagate_leaf_op(c, &hugr, *n, &vs[..], sig.output_count()), for (p, v) in (0..).map(OutgoingPort::from).zip(outs); // DFG @@ -211,9 +246,15 @@ pub(super) fn run_datalog( io_node(func, outp, IO::Output), in_wire_value(outp, p, v); }; - DatalogResults { + let out_wire_values = all_results + .out_wire_value + .iter() + .map(|(n, p, v)| (Wire::new(*n, *p), v.clone())) + .collect(); + AnalysisResults { + hugr, + out_wire_values, in_wire_value: all_results.in_wire_value, - out_wire_value: all_results.out_wire_value, case_reachable: all_results.case_reachable, bb_reachable: all_results.bb_reachable, } diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/results.rs similarity index 65% rename from hugr-passes/src/dataflow/machine.rs rename to hugr-passes/src/dataflow/results.rs index f7f97463d..0be8072b0 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -2,66 +2,15 @@ use std::collections::HashMap; use hugr_core::{ops::Value, types::ConstTypeError, HugrView, IncomingPort, Node, PortIndex, Wire}; -use super::datalog::{run_datalog, DatalogResults}; -use super::{AbstractValue, DFContext, PartialValue}; - -/// Basic structure for performing an analysis. Usage: -/// 1. Get a new instance via [Self::default()] -/// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values -/// 3. Call [Self::run] to produce [AnalysisResults] which can be inspected via -/// [read_out_wire](AnalysisResults::read_out_wire) -pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); +use super::{AbstractValue, PartialValue}; /// Results of a dataflow analysis, packaged with context for easy inspection pub struct AnalysisResults { - hugr: H, - results: DatalogResults, - out_wire_values: HashMap>, -} - -/// derived-Default requires the context to be Defaultable, which is unnecessary -impl Default for Machine { - fn default() -> Self { - Self(Default::default()) - } -} - -impl Machine { - /// Provide initial values for some wires. - // Likely for test purposes only - should we make non-pub or #[cfg(test)] ? - pub fn prepopulate_wire(&mut self, h: &impl HugrView, wire: Wire, value: PartialValue) { - self.0.extend( - h.linked_inputs(wire.node(), wire.source()) - .map(|(n, inp)| (n, inp, value.clone())), - ); - } - - /// Run the analysis (iterate until a lattice fixpoint is reached), - /// given initial values for some of the root node inputs. - /// (Note that `in_values` will not be useful for `Case` or `DFB`-rooted Hugrs, - /// but should handle other containers.) - /// The context passed in allows interpretation of leaf operations. - pub fn run( - mut self, - context: &impl DFContext, - hugr: H, - in_values: impl IntoIterator)>, - ) -> AnalysisResults { - let root = hugr.root(); - self.0 - .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); - let results = run_datalog(self.0, context, &hugr); - let out_wire_values = results - .out_wire_value - .iter() - .map(|(n, p, v)| (Wire::new(*n, *p), v.clone())) - .collect(); - AnalysisResults { - hugr, - results, - out_wire_values, - } - } + pub(super) hugr: H, + pub(super) in_wire_value: Vec<(Node, IncomingPort, PartialValue)>, + pub(super) case_reachable: Vec<(Node, Node)>, + pub(super) bb_reachable: Vec<(Node, Node)>, + pub(super) out_wire_values: HashMap>, } impl AnalysisResults { @@ -79,8 +28,7 @@ impl AnalysisResults { self.hugr.get_optype(node).as_tail_loop()?; let [_, out] = self.hugr.get_io(node).unwrap(); Some(TailLoopTermination::from_control_value( - self.results - .in_wire_value + self.in_wire_value .iter() .find_map(|(n, p, v)| (*n == out && p.index() == 0).then_some(v)) .unwrap(), @@ -99,8 +47,7 @@ impl AnalysisResults { let cond = self.hugr.get_parent(case)?; self.hugr.get_optype(cond).as_conditional()?; Some( - self.results - .case_reachable + self.case_reachable .iter() .any(|(cond2, case2)| &cond == cond2 && &case == case2), ) @@ -120,8 +67,7 @@ impl AnalysisResults { return None; }; Some( - self.results - .bb_reachable + self.bb_reachable .iter() .any(|(cfg2, bb2)| *cfg2 == cfg && *bb2 == bb), )