Skip to content

Commit

Permalink
Intern candidates and results in proof tree.
Browse files Browse the repository at this point in the history
  • Loading branch information
gavinleroy committed Feb 26, 2024
1 parent 5f74ec2 commit cda0077
Show file tree
Hide file tree
Showing 8 changed files with 304 additions and 173 deletions.
220 changes: 220 additions & 0 deletions crates/argus/src/proof_tree/interners.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
use std::{
cmp::{Eq, PartialEq},
hash::Hash,
};

use index_vec::{Idx, IndexVec};
use rustc_data_structures::{fx::FxHashMap as HashMap, stable_hasher::Hash64};
use rustc_hir::def_id::DefId;
use rustc_infer::infer::InferCtxt;
use rustc_trait_selection::{
solve::inspect::{InspectCandidate, InspectGoal},
traits::{
solve,
solve::{inspect::ProbeKind, CandidateSource},
},
};

use super::*;
use crate::{ext::TyCtxtExt, types::intermediate::EvaluationResult};

#[derive(Default)]
struct Interner<K: PartialEq + Eq + Hash, I: Idx, D> {
values: IndexVec<I, D>,
keys: HashMap<K, I>,
}

impl<K, I, D> Interner<K, I, D>
where
K: PartialEq + Eq + Hash,
I: Idx,
{
fn default() -> Self {
Self {
values: IndexVec::default(),
keys: HashMap::default(),
}
}

fn get(&mut self, key: K) -> Option<I> {
self.keys.get(&key).cloned()
}

fn insert(&mut self, k: K, d: D) -> I {
let idx = self.values.push(d);
self.keys.insert(k, idx);
idx
}

fn insert_no_key(&mut self, d: D) -> I {
self.values.push(d)
}
}

pub struct Interners {
goals: Interner<(Hash64, ResultIdx), GoalIdx, GoalData>,
candidates: Interner<CanKey, CandidateIdx, CandidateData>,
results: Interner<EvaluationResult, ResultIdx, ResultData>,
}

#[derive(PartialEq, Eq, Hash)]
enum CanKey {
Impl(DefId),
ParamEnv(usize),
Str(&'static str),
}

impl Interners {
pub fn default() -> Self {
Self {
goals: Interner::default(),
candidates: Interner::default(),
results: Interner::default(),
}
}

pub fn take(
self,
) -> (
IndexVec<GoalIdx, GoalData>,
IndexVec<CandidateIdx, CandidateData>,
IndexVec<ResultIdx, ResultData>,
) {
(
self.goals.values,
self.candidates.values,
self.results.values,
)
}

pub fn mk_result_node(&mut self, result: EvaluationResult) -> Node {
Node::Result(self.intern_result(result))
}

pub fn mk_goal_node<'tcx>(&mut self, goal: &InspectGoal<'_, 'tcx>) -> Node {
let infcx = goal.infcx();
let result_idx = self.intern_result(goal.result());
let goal = goal.goal();
let goal_idx = self.intern_goal(infcx, &goal, result_idx);
Node::Goal(goal_idx, result_idx)
}

pub fn mk_candidate_node<'tcx>(
&mut self,
candidate: &InspectCandidate<'_, 'tcx>,
) -> Node {
let can_idx = match candidate.kind() {
ProbeKind::Root { .. } => self.intern_can_string("root"),
ProbeKind::NormalizedSelfTyAssembly => {
self.intern_can_string("normalized-self-ty-asm")
}
ProbeKind::UnsizeAssembly => self.intern_can_string("unsize-asm"),
ProbeKind::CommitIfOk => self.intern_can_string("commit-if-ok"),
ProbeKind::UpcastProjectionCompatibility => {
self.intern_can_string("upcase-proj-compat")
}
ProbeKind::MiscCandidate { .. } => self.intern_can_string("misc"),
ProbeKind::TraitCandidate { source, .. } => match source {
CandidateSource::BuiltinImpl(_built_impl) => {
self.intern_can_string("builtin")
}
CandidateSource::AliasBound => self.intern_can_string("alias-bound"),
// The only two we really care about.
CandidateSource::ParamEnv(idx) => self.intern_can_param_env(idx),

CandidateSource::Impl(def_id) => {
self.intern_impl(candidate.infcx(), def_id)
}
},
};

Node::Candidate(can_idx)
}

fn intern_result(&mut self, result: EvaluationResult) -> ResultIdx {
if let Some(result_idx) = self.results.get(result) {
return result_idx;
}

self.results.insert(result, ResultData(result))
}

fn intern_goal<'tcx>(
&mut self,
infcx: &InferCtxt<'tcx>,
goal: &solve::Goal<'tcx, ty::Predicate<'tcx>>,
result_idx: ResultIdx,
) -> GoalIdx {
let goal = infcx.resolve_vars_if_possible(*goal);
let hash = infcx.predicate_hash(&goal.predicate);
let hash = (hash, result_idx);
if let Some(goal_idx) = self.goals.get(hash) {
return goal_idx;
}

let necessity = infcx.guess_predicate_necessity(&goal.predicate);
let num_vars =
serialize::var_counter::count_vars(infcx.tcx, goal.predicate);
let is_lhs_ty_var = goal.predicate.is_lhs_ty_var();
let goal_value = serialize_to_value(infcx, &GoalPredicateDef(goal))
.expect("failed to serialize goal");

self.goals.insert(hash, GoalData {
value: goal_value,
necessity,
num_vars,
is_lhs_ty_var,

#[cfg(debug_assertions)]
debug_comparison: format!("{:?}", goal.predicate.kind().skip_binder()),
})
}

fn intern_can_string(&mut self, s: &'static str) -> CandidateIdx {
if let Some(i) = self.candidates.get(CanKey::Str(s)) {
return i;
}

self.candidates.insert(CanKey::Str(s), s.into())
}

fn intern_can_param_env(&mut self, idx: usize) -> CandidateIdx {
if let Some(i) = self.candidates.get(CanKey::ParamEnv(idx)) {
return i;
}

self
.candidates
.insert(CanKey::ParamEnv(idx), CandidateData::ParamEnv(idx))
}

fn intern_impl(&mut self, infcx: &InferCtxt, def_id: DefId) -> CandidateIdx {
if let Some(i) = self.candidates.get(CanKey::Impl(def_id)) {
return i;
}

let tcx = infcx.tcx;

// First, try to get an impl header from the def_id ty
if let Some(header) = tcx.get_impl_header(def_id) {
return self.candidates.insert(
CanKey::Impl(def_id),
CandidateData::new_impl_header(infcx, &header),
);
}

// Second, try to get the span of the impl or just default to a fallback.
let string = tcx
.span_of_impl(def_id)
.map(|sp| {
tcx
.sess
.source_map()
.span_to_snippet(sp)
.unwrap_or_else(|_| "failed to find impl".to_string())
})
.unwrap_or_else(|symb| format!("foreign impl from: {}", symb.as_str()));

self.candidates.insert_no_key(CandidateData::from(string))
}
}
57 changes: 28 additions & 29 deletions crates/argus/src/proof_tree/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Proof tree types sent to the Argus frontend.
pub mod ext;
pub(self) mod interners;
pub(super) mod serialize;
pub mod topology;

Expand All @@ -26,7 +27,9 @@ use crate::{
crate::define_idx! {
usize,
ProofNodeIdx,
GoalIdx
GoalIdx,
CandidateIdx,
ResultIdx
}

// FIXME: Nodes shouldn't be PartialEq, or Eq. They are currently
Expand All @@ -38,29 +41,15 @@ crate::define_idx! {
#[cfg_attr(feature = "testing", derive(TS))]
#[cfg_attr(feature = "testing", ts(export))]
pub enum Node {
Result(
#[serde(with = "EvaluationResultDef")]
#[cfg_attr(feature = "testing", ts(type = "EvaluationResult"))]
EvaluationResult,
),
Candidate(Candidate),
Goal(
GoalIdx,
#[serde(with = "EvaluationResultDef")]
#[cfg_attr(feature = "testing", ts(type = "EvaluationResult"))]
EvaluationResult,
),
Result(ResultIdx),
Candidate(CandidateIdx),
Goal(GoalIdx, ResultIdx),
}

#[derive(Serialize, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "testing", derive(TS))]
#[cfg_attr(feature = "testing", ts(export))]
pub struct Goal {}

#[derive(Serialize, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "testing", derive(TS))]
#[cfg_attr(feature = "testing", ts(export))]
pub enum Candidate {
pub enum CandidateData {
Impl(
#[cfg_attr(feature = "testing", ts(type = "ImplHeader"))] serde_json::Value,
),
Expand All @@ -69,6 +58,15 @@ pub enum Candidate {
Any(String),
}

#[derive(Serialize, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "testing", derive(TS))]
#[cfg_attr(feature = "testing", ts(export))]
pub struct ResultData(
#[serde(with = "EvaluationResultDef")]
#[cfg_attr(feature = "testing", ts(type = "EvaluationResult"))]
EvaluationResult,
);

#[derive(Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "testing", derive(TS))]
Expand All @@ -95,9 +93,16 @@ pub struct SerializedTree {

#[cfg_attr(feature = "testing", ts(type = "Node[]"))]
pub nodes: IndexVec<ProofNodeIdx, Node>,

#[cfg_attr(feature = "testing", ts(type = "GoalData[]"))]
pub goals: IndexVec<GoalIdx, GoalData>,

#[cfg_attr(feature = "testing", ts(type = "CandidateData[]"))]
pub candidates: IndexVec<CandidateIdx, CandidateData>,

#[cfg_attr(feature = "testing", ts(type = "ResultData[]"))]
pub results: IndexVec<ResultIdx, ResultData>,

pub topology: TreeTopology,
pub error_leaves: Vec<ProofNodeIdx>,
pub unnecessary_roots: HashSet<ProofNodeIdx>,
Expand All @@ -113,31 +118,25 @@ pub struct ProofCycle(Vec<ProofNodeIdx>);
// ----------------------------------------
// impls

impl Candidate {
impl CandidateData {
fn new_impl_header<'tcx>(
infcx: &InferCtxt<'tcx>,
impl_: &ImplHeader<'tcx>,
) -> Self {
let impl_ =
serialize_to_value(infcx, impl_).expect("couldn't serialize impl header");

Self::Impl(impl_)
}

// TODO: we should pass the ParamEnv here for certainty.
fn new_param_env(idx: usize) -> Self {
Self::ParamEnv(idx)
}
}

impl From<&'static str> for Candidate {
impl From<&'static str> for CandidateData {
fn from(value: &'static str) -> Self {
value.to_string().into()
}
}

impl From<String> for Candidate {
impl From<String> for CandidateData {
fn from(value: String) -> Self {
Candidate::Any(value)
Self::Any(value)
}
}
Loading

0 comments on commit cda0077

Please sign in to comment.