Skip to content

Commit

Permalink
Improve ML search performance
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasAlaif committed Oct 31, 2024
1 parent e687a1c commit e4a35a9
Show file tree
Hide file tree
Showing 13 changed files with 248 additions and 149 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
Cargo.lock
.DS_STORE
/logs
*.log
34 changes: 17 additions & 17 deletions axiom-profiler-GUI/src/results/filters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use petgraph::{
use smt_log_parser::{
analysis::{
analysis::matching_loop::MLGraphNode,
raw::{Node, NodeKind, RawInstGraph},
raw::{IndexesInstGraph, Node, NodeKind, RawInstGraph},
InstGraph, RawNodeIndex,
},
display_with::{DisplayCtxt, DisplayWithCtxt},
Expand Down Expand Up @@ -151,26 +151,26 @@ impl Filter {
}
})
.collect::<fxhash::FxHashSet<_>>();
let relevant_non_qi_nodes: Vec<_> =
Dfs::new(&*graph.raw.graph, nth_ml_endnode.1[0].1 .0)
.iter(graph.raw.rev())
.filter(|nx| graph.raw.graph[*nx].kind().inst().is_none())
.filter(|nx| {
graph.raw.graph[*nx]
.inst_children
let start = nth_ml_endnode.1[0].1.index(&graph.raw).0;
let relevant_non_qi_nodes: Vec<_> = Dfs::new(&*graph.raw.graph, start)
.iter(graph.raw.rev())
.filter(|nx| graph.raw.graph[*nx].kind().inst().is_none())
.filter(|nx| {
graph.raw.graph[*nx]
.inst_children
.nodes
.intersection(&nodes_of_nth_matching_loop)
.count()
> 0
&& graph.raw.graph[*nx]
.inst_parents
.nodes
.intersection(&nodes_of_nth_matching_loop)
.count()
> 0
&& graph.raw.graph[*nx]
.inst_parents
.nodes
.intersection(&nodes_of_nth_matching_loop)
.count()
> 0
})
.map(RawNodeIndex)
.collect();
})
.map(RawNodeIndex)
.collect();
graph
.raw
.set_visibility_many(false, relevant_non_qi_nodes.into_iter());
Expand Down
137 changes: 137 additions & 0 deletions smt-log-parser/src/analysis/graph/analysis/matching_loop/analysis.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
use fxhash::FxHashSet;

use crate::{
analysis::{
analysis::{MlEndNodes, TopoAnalysis},
raw::Node,
InstGraph, RawNodeIndex,
},
items::InstIdx,
FxHashMap, Z3Parser,
};

use super::MlSignature;

pub struct MlAnalysis {
pub data: Vec<MlSignature>,
pub node_to_ml: FxHashMap<InstIdx, MlNodeInfo>,
}

impl MlAnalysis {
pub fn new(parser: &Z3Parser, signatures: Vec<(MlSignature, FxHashSet<InstIdx>)>) -> Self {
let mut node_to_ml = FxHashMap::<InstIdx, MlNodeInfo>::default();
let data = signatures
.into_iter()
.enumerate()
.map(|(i, (sig, iidxs))| {
node_to_ml.extend(
iidxs
.into_iter()
.map(|iidx| (iidx, MlNodeInfo::new(parser, iidx, i))),
);
sig
})
.collect();
Self { data, node_to_ml }
}

pub fn finalise(self, min_depth: u32) -> (MlEndNodes, FxHashSet<InstIdx>) {
let mut ml_end_nodes: MlEndNodes =
self.data.into_iter().map(|sig| (sig, Vec::new())).collect();
let mut ml_nodes = FxHashSet::default();
for (iidx, data) in self.node_to_ml.iter() {
if !data.is_root || data.max_depth < min_depth {
continue;
}
ml_nodes.insert(*iidx);
Self::walk_tree(&self.node_to_ml, data, &mut ml_nodes);
ml_end_nodes[data.ml_sig].1.push((data.max_depth, *iidx));
}
ml_end_nodes.retain_mut(|(_, v)| {
if v.is_empty() {
false
} else {
v.sort_unstable_by_key(|(len, idx)| (u32::MAX - *len, *idx));
true
}
});
(ml_end_nodes, ml_nodes)
}

pub fn walk_tree(
node_to_ml: &FxHashMap<InstIdx, MlNodeInfo>,
data: &MlNodeInfo,
ml_nodes: &mut FxHashSet<InstIdx>,
) {
for &reachable in &data.tree_above {
if ml_nodes.insert(reachable) {
let data = &node_to_ml[&reachable];
Self::walk_tree(node_to_ml, data, ml_nodes);
}
}
}
}

#[derive(Clone, Debug)]
pub struct MlNodeInfo {
pub is_root: bool,
pub ml_sig: usize,
pub max_depth: u32,
pub ast_size: u32,
pub tree_above: FxHashSet<InstIdx>,
}

impl MlNodeInfo {
pub fn new(parser: &Z3Parser, iidx: InstIdx, ml_sig: usize) -> Self {
Self {
is_root: true,
ml_sig,
max_depth: 0,
ast_size: parser.inst_ast_size(iidx),
tree_above: FxHashSet::default(),
}
}
}

impl TopoAnalysis<true, false> for MlAnalysis {
type Value = FxHashSet<InstIdx>;

fn collect<'a, 'n, T: Iterator<Item = (RawNodeIndex, &'n Self::Value)>>(
&mut self,
graph: &'a InstGraph,
idx: RawNodeIndex,
_node: &'a Node,
from_all: impl Fn() -> T,
) -> Self::Value
where
Self::Value: 'n,
{
let mut self_info = FxHashSet::default();
for (_, info) in from_all() {
self_info.extend(info.iter().copied());
}

let Some(iidx) = graph.raw[idx].kind().inst() else {
return self_info;
};
let Some(mut curr_info) = self.node_to_ml.remove(&iidx) else {
return self_info;
};

self_info.retain(|&prev_iidx| {
let prev_info = self.node_to_ml.get_mut(&prev_iidx).unwrap();
if prev_info.ml_sig == curr_info.ml_sig && prev_info.ast_size <= curr_info.ast_size {
prev_info.is_root = false;
curr_info.max_depth = curr_info.max_depth.max(prev_info.max_depth + 1);
curr_info.tree_above.insert(prev_iidx);
false
} else {
true
}
});

self.node_to_ml.insert(iidx, curr_info);
self_info.insert(iidx);
self_info
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
mod analysis;
mod generalise;
mod node;
mod search;
mod signature;

pub use analysis::*;
pub use node::*;
pub use signature::*;

Expand Down
44 changes: 16 additions & 28 deletions smt-log-parser/src/analysis/graph/analysis/matching_loop/search.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
use fxhash::FxHashSet;
use petgraph::visit::{Dfs, Reversed};

use crate::{
analysis::{
analysis::MlEndNodes, raw::NodeKind, visible::VisibleEdge, InstGraph, RawNodeIndex,
analysis::{matching_loop::MlAnalysis, MlEndNodes},
raw::{IndexesInstGraph, NodeKind},
visible::VisibleEdge,
InstGraph,
},
items::{GraphIdx, InstIdx, TermIdx, TimeRange},
Z3Parser,
Expand All @@ -29,7 +33,11 @@ impl InstGraph {

let long_path_leaves_sub_idx: Vec<_> = long_path_leaves
.iter()
.map(|(_, leaves)| long_paths_subgraph.reverse(leaves[0].1).unwrap())
.map(|(_, leaves)| {
long_paths_subgraph
.reverse(leaves[0].1.index(&self.raw))
.unwrap()
})
.collect();

// assign to each node in a matching loop which matching loops it belongs to, i.e., if a node is part of the
Expand Down Expand Up @@ -93,8 +101,8 @@ impl InstGraph {
/// Per each quantifier, finds the nodes that are part paths of length at
/// least `MIN_MATCHING_LOOP_LENGTH`. Additionally, returns a list of the
/// endpoints of these paths.
fn find_long_paths_per_quant(&mut self, parser: &Z3Parser) -> (MlEndNodes, Vec<RawNodeIndex>) {
let signatures = Self::collect_ml_signatures(parser);
fn find_long_paths_per_quant(&mut self, parser: &Z3Parser) -> (MlEndNodes, FxHashSet<InstIdx>) {
let signatures = self.collect_ml_signatures(parser);
// Collect all signatures instantiated at least `MIN_MATCHING_LOOP_LENGTH` times
let mut signatures: Vec<_> = signatures
.into_iter()
Expand All @@ -107,32 +115,12 @@ impl InstGraph {
.then_some((sig, insts))
})
.collect();
signatures.sort_unstable_by(|a, b| a.0.cmp(&b.0));
// eprintln!("Found {} signatures", signatures.len());
signatures.sort_unstable_by(|a, b| a.0.cmp(&b.0));

let mut long_path_leaves = Vec::new();
let mut long_path_nodes = Vec::new();
for (sig, insts) in signatures {
// eprintln!("Checking signature: {}", sig.to_string(parser));
self.raw.reset_visibility_to(true);
let to_raw = self.raw.inst_to_raw_idx();
self.raw
.set_visibility_many(false, insts.iter().copied().map(to_raw));
let mut single_quant_subgraph = self.to_visible_opt();

let max_depths =
single_quant_subgraph.compute_longest_distances_from_roots(self, parser);
let (leaves, nodes) = single_quant_subgraph
.collect_nodes_in_long_paths(&max_depths, MIN_MATCHING_LOOP_LENGTH);
let mut leaves: Vec<_> = leaves.collect();
if leaves.is_empty() {
continue;
}
leaves.sort_unstable_by_key(|(len, idx)| (u32::MAX - *len, *idx));
long_path_leaves.push((sig, leaves));
long_path_nodes.extend(nodes);
}
(long_path_leaves, long_path_nodes)
let mut analysis = MlAnalysis::new(parser, signatures);
self.topo_analysis(&mut analysis);
analysis.finalise(MIN_MATCHING_LOOP_LENGTH)
}

pub fn found_matching_loops(&self) -> Option<usize> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
analysis::InstGraph,
display_with::{DisplayCtxt, DisplayWithCtxt},
formatter::TermDisplayContext,
items::{ENodeIdx, InstIdx, Instantiation, QuantIdx, TermIdx},
items::{ENodeIdx, GraphIdx, InstIdx, QuantIdx, TermIdx},
FxHashMap, Z3Parser,
};

Expand All @@ -17,6 +17,7 @@ pub struct MlSignature {
pub quantifier: QuantIdx,
pub pattern: TermIdx,
pub parents: Box<[InstParent]>,
pub subgraph: GraphIdx,
}

/// For each pattern in the matched pattern, where did the blamed term come
Expand All @@ -33,8 +34,10 @@ pub enum InstParent {
}

impl MlSignature {
pub fn new(parser: &Z3Parser, inst: &Instantiation) -> Option<Self> {
let match_ = &parser[inst.match_];
pub fn new(graph: &InstGraph, parser: &Z3Parser, inst: InstIdx) -> Option<Self> {
let subgraph = graph.raw[inst].subgraph?.0;

let match_ = &parser[parser[inst].match_];
let pattern = match_.kind.pattern()?;
// If it has a pattern then definitely also has a quant_idx
let quant_idx = match_.kind.quant_idx().unwrap();
Expand All @@ -56,6 +59,7 @@ impl MlSignature {
quantifier: quant_idx,
pattern,
parents,
subgraph,
})
}

Expand All @@ -76,20 +80,22 @@ impl MlSignature {
.collect::<Vec<_>>()
.join(", ");
format!(
"{} {} {parents:?}",
"{} {} {parents:?} {:?}",
self.quantifier.with(&ctxt),
self.pattern,
self.subgraph,
)
}
}

impl InstGraph {
pub(super) fn collect_ml_signatures(
&self,
parser: &Z3Parser,
) -> FxHashMap<MlSignature, FxHashSet<InstIdx>> {
let mut signatures = FxHashMap::<_, FxHashSet<_>>::default();
for (iidx, inst) in parser.instantiations().iter_enumerated() {
let Some(ml_sig) = MlSignature::new(parser, inst) else {
for (iidx, _) in parser.instantiations().iter_enumerated() {
let Some(ml_sig) = MlSignature::new(self, parser, iidx) else {
continue;
};
signatures.entry(ml_sig).or_default().insert(iidx);
Expand Down
Loading

0 comments on commit e4a35a9

Please sign in to comment.