diff --git a/crates/flowistry/src/pdg/construct.rs b/crates/flowistry/src/pdg/construct.rs index 58bbb71dc..6d6d3813f 100644 --- a/crates/flowistry/src/pdg/construct.rs +++ b/crates/flowistry/src/pdg/construct.rs @@ -6,7 +6,7 @@ use flowistry_pdg::{CallString, GlobalLocation, RichLocation}; use itertools::Itertools; use log::{debug, trace}; use petgraph::graph::DiGraph; -use rustc_abi::FieldIdx; +use rustc_abi::{FieldIdx, VariantIdx}; use rustc_borrowck::consumers::{ places_conflict, BodyWithBorrowckFacts, PlaceConflictBias, }; @@ -21,6 +21,7 @@ use rustc_middle::{ ty::{GenericArg, GenericArgsRef, List, ParamEnv, TyCtxt, TyKind}, }; use rustc_mir_dataflow::{self as df}; +use rustc_span::sym::poll; use rustc_utils::{ mir::{borrowck_facts, control_dependencies::ControlDependencies}, BodyExt, PlaceExt, @@ -201,6 +202,25 @@ struct CallingContext<'tcx> { call_stack: Vec, } +/// Stores ids that are needed to construct projections around async functions. +struct AsyncInfo { + poll_ready_variant_idx: VariantIdx, + poll_ready_field_idx: FieldIdx, +} + +impl AsyncInfo { + fn make(tcx: TyCtxt) -> Option> { + let lang_items = tcx.lang_items(); + let poll_def = tcx.adt_def(lang_items.poll()?); + let ready_vid = lang_items.poll_ready_variant()?; + assert_eq!(poll_def.variant_with_id(ready_vid).fields.len(), 1); + Some(Rc::new(Self { + poll_ready_variant_idx: poll_def.variant_index_with_id(ready_vid), + poll_ready_field_idx: 0_u32.into(), + })) + } +} + pub struct GraphConstructor<'tcx> { tcx: TyCtxt<'tcx>, params: PdgParams<'tcx>, @@ -212,6 +232,7 @@ pub struct GraphConstructor<'tcx> { body_assignments: utils::BodyAssignments, calling_context: Option>, start_loc: FxHashSet, + async_info: Rc, } macro_rules! trylet { @@ -226,11 +247,20 @@ macro_rules! trylet { impl<'tcx> GraphConstructor<'tcx> { /// Creates a [`GraphConstructor`] at the root of the PDG. pub fn root(params: PdgParams<'tcx>) -> Self { - GraphConstructor::new(params, None) + let tcx = params.tcx; + GraphConstructor::new( + params, + None, + AsyncInfo::make(tcx).expect("async functions are not defined"), + ) } /// Creates [`GraphConstructor`] for a function resolved as `fn_resolution` in a given `calling_context`. - fn new(params: PdgParams<'tcx>, calling_context: Option>) -> Self { + fn new( + params: PdgParams<'tcx>, + calling_context: Option>, + async_info: Rc, + ) -> Self { let tcx = params.tcx; let def_id = params.root.def_id().expect_local(); let body_with_facts = borrowck_facts::get_body_with_borrowck_facts(tcx, def_id); @@ -268,6 +298,7 @@ impl<'tcx> GraphConstructor<'tcx> { def_id, calling_context, body_assignments, + async_info, } } @@ -672,9 +703,12 @@ impl<'tcx> GraphConstructor<'tcx> { let (parent_place, child_projection) = match call_kind { // Async return must be handled special, because it gets wrapped in `Poll::Ready` CallKind::AsyncPoll if child.local == RETURN_PLACE => { - let in_poll = - destination.project_deeper(&[PlaceElem::Downcast(None, 0_u32.into())], tcx); - let field_idx = 0_u32.into(); + let async_info = self.async_info.as_ref(); + let in_poll = destination.project_deeper( + &[PlaceElem::Downcast(None, async_info.poll_ready_variant_idx)], + tcx, + ); + let field_idx = async_info.poll_ready_field_idx; let child_inner_return_type = in_poll .ty(parent_body.local_decls(), tcx) .field_ty(tcx, field_idx); @@ -750,7 +784,8 @@ impl<'tcx> GraphConstructor<'tcx> { param_env, call_stack, }; - let child_constructor = GraphConstructor::new(params, Some(calling_context)); + let child_constructor = + GraphConstructor::new(params, Some(calling_context), self.async_info.clone()); if let Some(callback) = &self.params.call_change_callback { let info = CallInfo { @@ -961,7 +996,12 @@ impl<'tcx> GraphConstructor<'tcx> { call_string, call_stack, }; - return GraphConstructor::new(params, Some(calling_context)).construct_partial(); + return GraphConstructor::new( + params, + Some(calling_context), + self.async_info.clone(), + ) + .construct_partial(); } let mut analysis = DfAnalysis(self)