diff --git a/crates/flowistry/src/pdg/construct.rs b/crates/flowistry/src/pdg/construct.rs index f59f154c5..e60f6c0e1 100644 --- a/crates/flowistry/src/pdg/construct.rs +++ b/crates/flowistry/src/pdg/construct.rs @@ -632,30 +632,20 @@ impl<'tcx> GraphConstructor<'tcx> { } } - enum CallKind { - /// A standard function call like `f(x)`. - Direct, - /// A call to a function variable, like `fn foo(f: impl Fn()) { f() }` - Indirect, - /// A poll to an async function, like `f.await`. - AsyncPoll, - } - // Determine the type of call-site. - let (call_kind, args) = match tcx.def_path_str(called_def_id).as_str() { - "std::ops::Fn::call" => (CallKind::Indirect, args), - "std::future::Future::poll" => { - let args = self.find_async_args(args)?; - (CallKind::AsyncPoll, args) - } - def_path => { - if resolved_def_id.is_local() { - (CallKind::Direct, args) - } else { - trace!(" Bailing because func is non-local: `{def_path}`"); - return None; - } - } + let Some(call_kind) = CallKind::classify(tcx, called_def_id) else { + trace!( + " Bailing because func is non-local: `{}`", + tcx.def_path_str(called_def_id) + ); + return None; + }; + + let args = if call_kind == CallKind::AsyncPoll { + self.find_async_args(args)? + } else { + args }; + trace!(" Handling call!"); // A helper to translate an argument (or return) in the child into a place in the parent. @@ -1011,6 +1001,35 @@ impl<'tcx> GraphConstructor<'tcx> { } } +#[derive(PartialEq, Eq)] +enum CallKind { + /// A standard function call like `f(x)`. + Direct, + /// A call to a function variable, like `fn foo(f: impl Fn()) { f() }` + Indirect, + /// A poll to an async function, like `f.await`. + AsyncPoll, +} + +impl CallKind { + /// Determine the type of call-site. + fn classify(tcx: TyCtxt, def_id: DefId) -> Option { + let lang_items = tcx.lang_items(); + if lang_items.future_poll_fn() == Some(def_id) { + return Some(Self::AsyncPoll); + } + let my_impl = tcx.impl_of_method(def_id)?; + let my_trait = tcx.trait_id_of_impl(my_impl)?; + if Some(my_trait) == lang_items.fn_trait() + || Some(my_trait) == lang_items.fn_mut_trait() + || Some(my_trait) == lang_items.fn_once_trait() + { + return Some(Self::Indirect); + } + def_id.is_local().then_some(Self::Direct) + } +} + struct DfAnalysis<'a, 'tcx>(&'a GraphConstructor<'tcx>); impl<'tcx> df::AnalysisDomain<'tcx> for DfAnalysis<'_, 'tcx> {