Skip to content

Commit

Permalink
Determine call type based on LanguageItems
Browse files Browse the repository at this point in the history
Also handles `FnMut` and `FnOnce` now
  • Loading branch information
JustusAdam committed Feb 8, 2024
1 parent 3b0a126 commit ecbc4ca
Showing 1 changed file with 42 additions and 23 deletions.
65 changes: 42 additions & 23 deletions crates/flowistry/src/pdg/construct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -632,30 +632,20 @@ impl<'tcx> GraphConstructor<'tcx> {
}
}

enum CallKind {
/// A standard function call like `f(x)`.
Direct,
/// A call to a function variable, like `fn foo(f: impl Fn()) { f() }`
Indirect,
/// A poll to an async function, like `f.await`.
AsyncPoll,
}
// Determine the type of call-site.
let (call_kind, args) = match tcx.def_path_str(called_def_id).as_str() {
"std::ops::Fn::call" => (CallKind::Indirect, args),
"std::future::Future::poll" => {
let args = self.find_async_args(args)?;
(CallKind::AsyncPoll, args)
}
def_path => {
if resolved_def_id.is_local() {
(CallKind::Direct, args)
} else {
trace!(" Bailing because func is non-local: `{def_path}`");
return None;
}
}
let Some(call_kind) = CallKind::classify(tcx, called_def_id) else {
trace!(
" Bailing because func is non-local: `{}`",
tcx.def_path_str(called_def_id)
);
return None;
};

let args = if call_kind == CallKind::AsyncPoll {
self.find_async_args(args)?
} else {
args
};

trace!(" Handling call!");

// A helper to translate an argument (or return) in the child into a place in the parent.
Expand Down Expand Up @@ -1011,6 +1001,35 @@ impl<'tcx> GraphConstructor<'tcx> {
}
}

#[derive(PartialEq, Eq)]
enum CallKind {
/// A standard function call like `f(x)`.
Direct,
/// A call to a function variable, like `fn foo(f: impl Fn()) { f() }`
Indirect,
/// A poll to an async function, like `f.await`.
AsyncPoll,
}

impl CallKind {
/// Determine the type of call-site.
fn classify(tcx: TyCtxt, def_id: DefId) -> Option<Self> {
let lang_items = tcx.lang_items();
if lang_items.future_poll_fn() == Some(def_id) {
return Some(Self::AsyncPoll);
}
let my_impl = tcx.impl_of_method(def_id)?;
let my_trait = tcx.trait_id_of_impl(my_impl)?;
if Some(my_trait) == lang_items.fn_trait()
|| Some(my_trait) == lang_items.fn_mut_trait()
|| Some(my_trait) == lang_items.fn_once_trait()
{
return Some(Self::Indirect);
}
def_id.is_local().then_some(Self::Direct)
}
}

struct DfAnalysis<'a, 'tcx>(&'a GraphConstructor<'tcx>);

impl<'tcx> df::AnalysisDomain<'tcx> for DfAnalysis<'_, 'tcx> {
Expand Down

0 comments on commit ecbc4ca

Please sign in to comment.