diff --git a/crates/flowistry/src/pdg/construct.rs b/crates/flowistry/src/pdg/construct.rs index f59f154c5..1b735407e 100644 --- a/crates/flowistry/src/pdg/construct.rs +++ b/crates/flowistry/src/pdg/construct.rs @@ -6,7 +6,7 @@ use flowistry_pdg::{CallString, GlobalLocation, RichLocation}; use itertools::Itertools; use log::{debug, trace}; use petgraph::graph::DiGraph; -use rustc_abi::FieldIdx; +use rustc_abi::{FieldIdx, VariantIdx}; use rustc_borrowck::consumers::{ places_conflict, BodyWithBorrowckFacts, PlaceConflictBias, }; @@ -14,9 +14,9 @@ use rustc_hash::{FxHashMap, FxHashSet}; use rustc_hir::def_id::{DefId, LocalDefId}; use rustc_middle::{ mir::{ - visit::Visitor, AggregateKind, BasicBlock, Body, Location, Operand, Place, PlaceElem, - Rvalue, Statement, StatementKind, Terminator, TerminatorEdges, TerminatorKind, - RETURN_PLACE, + visit::Visitor, AggregateKind, BasicBlock, Body, HasLocalDecls, Location, Operand, + Place, PlaceElem, Rvalue, Statement, StatementKind, Terminator, TerminatorEdges, + TerminatorKind, RETURN_PLACE, }, ty::{GenericArg, GenericArgsRef, List, ParamEnv, TyCtxt, TyKind}, }; @@ -201,6 +201,25 @@ struct CallingContext<'tcx> { call_stack: Vec, } +/// Stores ids that are needed to construct projections around async functions. +struct AsyncInfo { + poll_ready_variant_idx: VariantIdx, + poll_ready_field_idx: FieldIdx, +} + +impl AsyncInfo { + fn make(tcx: TyCtxt) -> Option> { + let lang_items = tcx.lang_items(); + let poll_def = tcx.adt_def(lang_items.poll()?); + let ready_vid = lang_items.poll_ready_variant()?; + assert_eq!(poll_def.variant_with_id(ready_vid).fields.len(), 1); + Some(Rc::new(Self { + poll_ready_variant_idx: poll_def.variant_index_with_id(ready_vid), + poll_ready_field_idx: 0_u32.into(), + })) + } +} + pub struct GraphConstructor<'tcx> { tcx: TyCtxt<'tcx>, params: PdgParams<'tcx>, @@ -212,6 +231,7 @@ pub struct GraphConstructor<'tcx> { body_assignments: utils::BodyAssignments, calling_context: Option>, start_loc: FxHashSet, + async_info: Rc, } macro_rules! trylet { @@ -226,11 +246,20 @@ macro_rules! trylet { impl<'tcx> GraphConstructor<'tcx> { /// Creates a [`GraphConstructor`] at the root of the PDG. pub fn root(params: PdgParams<'tcx>) -> Self { - GraphConstructor::new(params, None) + let tcx = params.tcx; + GraphConstructor::new( + params, + None, + AsyncInfo::make(tcx).expect("async functions are not defined"), + ) } /// Creates [`GraphConstructor`] for a function resolved as `fn_resolution` in a given `calling_context`. - fn new(params: PdgParams<'tcx>, calling_context: Option>) -> Self { + fn new( + params: PdgParams<'tcx>, + calling_context: Option>, + async_info: Rc, + ) -> Self { let tcx = params.tcx; let def_id = params.root.def_id().expect_local(); let body_with_facts = borrowck_facts::get_body_with_borrowck_facts(tcx, def_id); @@ -261,6 +290,7 @@ impl<'tcx> GraphConstructor<'tcx> { def_id, calling_context, body_assignments, + async_info, } } @@ -662,41 +692,57 @@ impl<'tcx> GraphConstructor<'tcx> { let parent_body = &self.body; let translate_to_parent = |child: Place<'tcx>| -> Option> { trace!(" Translating child place: {child:?}"); - let (parent_place, child_projection) = if child.local == RETURN_PLACE { - (destination, &child.projection[..]) - } else { - match call_kind { - // Map arguments to the argument array - CallKind::Direct => ( - args[child.local.as_usize() - 1].place()?, + let (parent_place, child_projection) = match call_kind { + // Async return must be handled special, because it gets wrapped in `Poll::Ready` + CallKind::AsyncPoll if child.local == RETURN_PLACE => { + let async_info = self.async_info.as_ref(); + let in_poll = destination.project_deeper( + &[PlaceElem::Downcast(None, async_info.poll_ready_variant_idx)], + tcx, + ); + let field_idx = async_info.poll_ready_field_idx; + let child_inner_return_type = in_poll + .ty(parent_body.local_decls(), tcx) + .field_ty(tcx, field_idx); + ( + in_poll.project_deeper( + &[PlaceElem::Field(field_idx, child_inner_return_type)], + tcx, + ), &child.projection[..], - ), - // Map arguments to projections of the future, the poll's first argument - CallKind::AsyncPoll => { - if child.local.as_usize() == 1 { - let PlaceElem::Field(idx, _) = child.projection[0] else { - panic!("Unexpected non-projection of async context") - }; - (args[idx.as_usize()].place()?, &child.projection[1 ..]) - } else { - return None; - } + ) + } + _ if child.local == RETURN_PLACE => (destination, &child.projection[..]), + // Map arguments to the argument array + CallKind::Direct => ( + args[child.local.as_usize() - 1].place()?, + &child.projection[..], + ), + // Map arguments to projections of the future, the poll's first argument + CallKind::AsyncPoll => { + if child.local.as_usize() == 1 { + let PlaceElem::Field(idx, _) = child.projection[0] else { + panic!("Unexpected non-projection of async context") + }; + (args[idx.as_usize()].place()?, &child.projection[1 ..]) + } else { + return None; } - // Map closure captures to the first argument. - // Map formal parameters to the second argument. - CallKind::Indirect => { - if child.local.as_usize() == 1 { - (args[0].place()?, &child.projection[..]) - } else { - let tuple_arg = args[1].place()?; - let _projection = child.projection.to_vec(); - let field = FieldIdx::from_usize(child.local.as_usize() - 2); - let field_ty = tuple_arg.ty(parent_body.as_ref(), tcx).field_ty(tcx, field); - ( - tuple_arg.project_deeper(&[PlaceElem::Field(field, field_ty)], tcx), - &child.projection[..], - ) - } + } + // Map closure captures to the first argument. + // Map formal parameters to the second argument. + CallKind::Indirect => { + if child.local.as_usize() == 1 { + (args[0].place()?, &child.projection[..]) + } else { + let tuple_arg = args[1].place()?; + let _projection = child.projection.to_vec(); + let field = FieldIdx::from_usize(child.local.as_usize() - 2); + let field_ty = tuple_arg.ty(parent_body.as_ref(), tcx).field_ty(tcx, field); + ( + tuple_arg.project_deeper(&[PlaceElem::Field(field, field_ty)], tcx), + &child.projection[..], + ) } } }; @@ -730,7 +776,8 @@ impl<'tcx> GraphConstructor<'tcx> { param_env, call_stack, }; - let child_constructor = GraphConstructor::new(params, Some(calling_context)); + let child_constructor = + GraphConstructor::new(params, Some(calling_context), self.async_info.clone()); if let Some(callback) = &self.params.call_change_callback { let info = CallInfo { @@ -941,7 +988,12 @@ impl<'tcx> GraphConstructor<'tcx> { call_string, call_stack, }; - return GraphConstructor::new(params, Some(calling_context)).construct_partial(); + return GraphConstructor::new( + params, + Some(calling_context), + self.async_info.clone(), + ) + .construct_partial(); } let mut analysis = DfAnalysis(self)