Skip to content

Commit

Permalink
Print Fn trait Output as function return, print binders, printing cle…
Browse files Browse the repository at this point in the history
…anup
  • Loading branch information
gavinleroy committed Jul 5, 2024
1 parent 835addc commit 639daa8
Show file tree
Hide file tree
Showing 17 changed files with 648 additions and 204 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 7 additions & 3 deletions crates/argus-ext/src/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub trait TyCtxtExt<'tcx> {

fn does_trait_ref_occur_in(
&self,
needle: ty::TraitRef<'tcx>,
needle: ty::PolyTraitRef<'tcx>,
haystack: ty::Predicate<'tcx>,
) -> bool;

Expand Down Expand Up @@ -130,14 +130,18 @@ pub fn retain_error_sources<'tcx, T>(

let is_implied_by_failing_bound = |other: &T| {
trait_preds.iter().any(|bound| {
let poly_tp = get_predicate(bound).expect_trait_predicate();
if let ty::TraitPredicate {
trait_ref,
polarity: ty::PredicatePolarity::Positive,
} = get_predicate(bound).expect_trait_predicate().skip_binder()
} = poly_tp.skip_binder()
// Don't consider reflexive implication
&& !eq(bound, other)
{
get_tcx(other).does_trait_ref_occur_in(trait_ref, get_predicate(other))
get_tcx(other).does_trait_ref_occur_in(
poly_tp.rebind(trait_ref),
get_predicate(other),
)
} else {
false
}
Expand Down
43 changes: 37 additions & 6 deletions crates/argus-ext/src/ty/impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,11 @@ impl<'tcx> TyCtxtExt<'tcx> for TyCtxt<'tcx> {

fn does_trait_ref_occur_in(
&self,
needle: ty::TraitRef<'tcx>,
needle: ty::PolyTraitRef<'tcx>,
haystack: ty::Predicate<'tcx>,
) -> bool {
struct TraitRefVisitor<'tcx> {
tr: ty::TraitRef<'tcx>,
tr: ty::PolyTraitRef<'tcx>,
tcx: TyCtxt<'tcx>,
found: bool,
}
Expand All @@ -171,15 +171,16 @@ impl<'tcx> TyCtxtExt<'tcx> for TyCtxt<'tcx> {
def_id: DefId,
) -> bool {
let my_ty = self.tr.self_ty();
let my_id = self.tr.def_id;
let my_id = self.tr.def_id();

if args.is_empty() {
return false;
}

// FIXME: is it always the first type in the args list?
let proj_ty = args.type_at(0);
proj_ty == my_ty && self.tcx.is_descendant_of(def_id, my_id)
proj_ty == my_ty.skip_binder()
&& self.tcx.is_descendant_of(def_id, my_id)
}
}

Expand All @@ -200,13 +201,43 @@ impl<'tcx> TyCtxtExt<'tcx> for TyCtxt<'tcx> {
log::debug!("* [{predicate:#?}]");

if let ty::PredicateKind::Clause(ty::ClauseKind::Projection(
ty::ProjectionPredicate {
pp @ ty::ProjectionPredicate {
projection_term, ..
},
)) = predicate.kind().skip_binder()
{
self.found |= self
use rustc_infer::traits::util::supertraits;

// Check whether the `TraitRef`, or any implied supertrait
// appear in the projection. This can happen for example if we have
// a trait predicate `F: Fn(i32) -> i32`, the projection of the `Output`
// would be `<F as FnOnce(i32)>::Output == i32`.

let simple_check = self
.occurs_in_projection(projection_term.args, projection_term.def_id);
let deep_check = || {
let prj_ply_trait_ref = predicate.kind().rebind(pp);
let poly_supertrait_ref =
prj_ply_trait_ref.required_poly_trait_ref(self.tcx);
// Check whether `poly_supertrait_ref` is a supertrait of `self.tr`.
// HACK FIXME: this is too simplistic, it's unsound to check
// *just* that the `self_ty`s are equivalent and that the `def_id` is
// a super trait...
log::debug!(
"deep_check:\n {:?}\n to super\n {:?}",
self.tr,
poly_supertrait_ref
);
for super_ptr in supertraits(self.tcx, self.tr) {
log::debug!("* against {super_ptr:?}");
if super_ptr == poly_supertrait_ref {
return true;
}
}
false
};

self.found |= simple_check || deep_check();
}

predicate.super_visit_with(self);
Expand Down
2 changes: 2 additions & 0 deletions crates/argus-ser/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ rustc_utils.workspace = true
serde.workspace = true
serde_json.workspace = true
smallvec = "1.11.2"
itertools = "0.12.0"
ts-rs = { version = "7.1.1", features = ["indexmap-impl"], optional = true }
index_vec = { version = "0.1.3", features = ["serde"] }
argus-ext = { version = "0.1.6", path = "../argus-ext" }

[dev-dependencies]
argus-ser = { path = ".", features = ["testing"] }
Expand Down
168 changes: 152 additions & 16 deletions crates/argus-ser/src/custom.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
//! Extensions to the type system for easier consumption.
use argus_ext::ty::TyCtxtExt;
use itertools::Itertools;
use rustc_data_structures::fx::FxIndexMap;
use rustc_hir::def_id::DefId;
use rustc_middle::ty;
use rustc_macros::TypeVisitable;
use rustc_middle::ty::{self, Upcast};
use serde::Serialize;
#[cfg(feature = "testing")]
use ts_rs::TS;
Expand Down Expand Up @@ -31,17 +34,72 @@ pub struct ImplHeader<'tcx> {
pub tys_without_default_bounds: Vec<ty::Ty<'tcx>>,
}

#[derive(Serialize)]
#[derive(Debug, Clone, TypeVisitable, Serialize)]
#[cfg_attr(feature = "testing", derive(TS))]
#[cfg_attr(feature = "testing", ts(export))]
pub struct GroupedClauses<'tcx> {
pub grouped: Vec<ClauseWithBounds<'tcx>>,
#[serde(with = "Slice__PolyClauseWithBoundsDef")]
#[cfg_attr(feature = "testing", ts(type = "PolyClauseWithBounds[]"))]
pub grouped: Vec<PolyClauseWithBounds<'tcx>>,

#[serde(with = "myty::Slice__ClauseDef")]
#[cfg_attr(feature = "testing", ts(type = "Clause[]"))]
pub other: Vec<ty::Clause<'tcx>>,
}

pub struct Slice__PolyClauseWithBoundsDef;
impl Slice__PolyClauseWithBoundsDef {
pub fn serialize<S>(
value: &[PolyClauseWithBounds],
s: S,
) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
#[derive(Serialize)]
struct Wrapper<'a, 'tcx: 'a>(
#[serde(with = "Binder__ClauseWithBounds")]
&'a PolyClauseWithBounds<'tcx>,
);

crate::serialize_custom_seq! { Wrapper, s, value }
}
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(feature = "testing", derive(TS))]
#[cfg_attr(feature = "testing", ts(export, rename = "PolyClauseWithBounds"))]
pub struct Binder__ClauseWithBounds<'tcx> {
value: ClauseWithBounds<'tcx>,

#[serde(with = "myty::Slice__BoundVariableKindDef")]
#[cfg_attr(feature = "testing", ts(type = "BoundVariableKind[]"))]
bound_vars: &'tcx ty::List<ty::BoundVariableKind>,
}

impl<'tcx> Binder__ClauseWithBounds<'tcx> {
pub fn new(value: &PolyClauseWithBounds<'tcx>) -> Self {
Self {
bound_vars: value.bound_vars(),
value: value.clone().skip_binder(),
}
}

pub fn serialize<S>(
value: &PolyClauseWithBounds<'tcx>,
s: S,
) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
Self::new(value).serialize(s)
}
}

type PolyClauseWithBounds<'tcx> = ty::Binder<'tcx, ClauseWithBounds<'tcx>>;

#[derive(Debug, Clone, TypeVisitable, Serialize)]
#[cfg_attr(feature = "testing", derive(TS))]
#[cfg_attr(feature = "testing", ts(export))]
pub struct ClauseWithBounds<'tcx> {
Expand All @@ -51,14 +109,24 @@ pub struct ClauseWithBounds<'tcx> {
pub bounds: Vec<ClauseBound<'tcx>>,
}

#[derive(Serialize)]
#[derive(Debug, Copy, Clone, TypeVisitable, Serialize)]
#[cfg_attr(feature = "testing", derive(TS))]
#[cfg_attr(feature = "testing", ts(export))]
pub enum ClauseBound<'tcx> {
Trait(
myty::Polarity,
#[serde(with = "myty::TraitRefPrintOnlyTraitPathDef")]
#[cfg_attr(feature = "testing", ts(type = "TraitRefPrintOnlyTraitPath"))]
crate::TraitRefPrintOnlyTraitPathDef<'tcx>,
ty::TraitRef<'tcx>,
),
FnTrait(
myty::Polarity,
#[serde(with = "myty::TraitRefPrintOnlyTraitPathDef")]
#[cfg_attr(feature = "testing", ts(type = "TraitRefPrintOnlyTraitPath"))]
ty::TraitRef<'tcx>,
#[serde(with = "myty::TyDef")]
#[cfg_attr(feature = "testing", ts(type = "Ty"))]
ty::Ty<'tcx>,
),
Region(
#[serde(with = "myty::RegionDef")]
Expand All @@ -68,37 +136,106 @@ pub enum ClauseBound<'tcx> {
}

pub fn group_predicates_by_ty<'tcx>(
tcx: ty::TyCtxt<'tcx>,
predicates: impl IntoIterator<Item = ty::Clause<'tcx>>,
) -> GroupedClauses<'tcx> {
// ARGUS: ADDITION: group predicates together based on `self_ty`.
let mut grouped: FxIndexMap<_, Vec<_>> = FxIndexMap::default();
let mut grouped = FxIndexMap::<_, Vec<_>>::default();
let mut other = vec![];

// TODO: this only looks at the output of an `FnOnce`, we probably also need
// to handle `AsyncFnOnce` and consider doing the same for the output of
// a `Future`. A further goal could be sugaring all associated type bounds
// back into the signature but that would require more work (not sure how much).
let fn_trait_output = tcx.lang_items().fn_once_output();
let mut fn_output_projections = vec![];

for p in predicates {
// TODO: all this binder skipping is a HACK.
if let Some(poly_trait_pred) = p.as_trait_clause() {
let ty = poly_trait_pred.self_ty().skip_binder();
let trait_ref =
poly_trait_pred.map_bound(|tr| tr.trait_ref).skip_binder();
let bound = ClauseBound::Trait(
poly_trait_pred.polarity().into(),
crate::TraitRefPrintOnlyTraitPathDef(trait_ref),
);
grouped.entry(ty).or_default().push(bound);
let bound =
ClauseBound::Trait(poly_trait_pred.polarity().into(), trait_ref);
grouped
.entry(ty)
.or_default()
.push(poly_trait_pred.rebind(bound));
} else if let Some(poly_ty_outl) = p.as_type_outlives_clause() {
let ty = poly_ty_outl.map_bound(|t| t.0).skip_binder();
let r = poly_ty_outl.map_bound(|t| t.1).skip_binder();
let bound = ClauseBound::Region(r);
grouped.entry(ty).or_default().push(bound);
grouped
.entry(ty)
.or_default()
.push(poly_ty_outl.rebind(bound));
} else if let Some(poly_projection) = p.as_projection_clause()
&& let Some(output_defid) = fn_trait_output
&& poly_projection.projection_def_id() == output_defid
{
fn_output_projections.push(poly_projection);
} else {
other.push(p);
}
}

let grouped = grouped
.into_iter()
.map(|(ty, bounds)| ClauseWithBounds { ty, bounds })
.map(|(ty, bounds)| {
// NOTE: we have to call unique to make a `List`
let all_bound_vars = bounds.iter().flat_map(|b| b.bound_vars()).unique();
let bound_vars = tcx.mk_bound_variable_kinds_from_iter(all_bound_vars);
let unbounds = bounds
.into_iter()
.map(|bclause| {
let clause = bclause.skip_binder();
if let ClauseBound::Trait(p, tref) = clause
&& tcx.is_fn_trait(tref.def_id)
&& let poly_tr = bclause.rebind(tref)
&& let matching_projections = fn_output_projections
.extract_if(move |p| {
tcx.does_trait_ref_occur_in(
poly_tr,
p.map_bound(|p| {
ty::PredicateKind::Clause(ty::ClauseKind::Projection(p))
})
.upcast(tcx),
)
})
.unique()
.collect::<smallvec::SmallVec<[_; 2]>>()
&& !matching_projections.is_empty()
{
log::debug!(
"Matching projections for {bclause:?} {matching_projections:#?}"
);
let ret_ty = matching_projections[0]
.term()
.skip_binder()
.ty()
.expect("FnOnce::Output Ty");
debug_assert!(matching_projections.len() == 1);
ClauseBound::FnTrait(p, tref, ret_ty)
} else {
clause
}
})
.collect();
ty::Binder::bind_with_vars(
ClauseWithBounds {
ty,
bounds: unbounds,
},
bound_vars,
)
})
.collect::<Vec<_>>();

assert!(
fn_output_projections.is_empty(),
"Remaining output projections {fn_output_projections:#?}"
);

GroupedClauses { grouped, other }
}

Expand Down Expand Up @@ -148,7 +285,7 @@ pub fn get_opt_impl_header(
log::debug!("pretty predicates for impl header {:#?}", pretty_predicates);

// Argus addition
let grouped_clauses = group_predicates_by_ty(pretty_predicates);
let grouped_clauses = group_predicates_by_ty(tcx, pretty_predicates);

let tys_without_default_bounds =
types_without_default_bounds.into_iter().collect::<Vec<_>>();
Expand All @@ -158,7 +295,6 @@ pub fn get_opt_impl_header(
name,
self_ty,
predicates: grouped_clauses,
// predicates: pretty_predicates,
tys_without_default_bounds,
})
}
Loading

0 comments on commit 639daa8

Please sign in to comment.