diff --git a/Cargo.lock b/Cargo.lock index e9d8a70..0c94d02 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -83,9 +83,11 @@ name = "argus-ser" version = "0.1.6" dependencies = [ "anyhow", + "argus-ext", "argus-ser", "fluid-let", "index_vec", + "itertools", "log", "rustc_utils", "serde", diff --git a/crates/argus-ext/src/ty.rs b/crates/argus-ext/src/ty.rs index 12e8e47..f342f9b 100644 --- a/crates/argus-ext/src/ty.rs +++ b/crates/argus-ext/src/ty.rs @@ -60,7 +60,7 @@ pub trait TyCtxtExt<'tcx> { fn does_trait_ref_occur_in( &self, - needle: ty::TraitRef<'tcx>, + needle: ty::PolyTraitRef<'tcx>, haystack: ty::Predicate<'tcx>, ) -> bool; @@ -130,14 +130,18 @@ pub fn retain_error_sources<'tcx, T>( let is_implied_by_failing_bound = |other: &T| { trait_preds.iter().any(|bound| { + let poly_tp = get_predicate(bound).expect_trait_predicate(); if let ty::TraitPredicate { trait_ref, polarity: ty::PredicatePolarity::Positive, - } = get_predicate(bound).expect_trait_predicate().skip_binder() + } = poly_tp.skip_binder() // Don't consider reflexive implication && !eq(bound, other) { - get_tcx(other).does_trait_ref_occur_in(trait_ref, get_predicate(other)) + get_tcx(other).does_trait_ref_occur_in( + poly_tp.rebind(trait_ref), + get_predicate(other), + ) } else { false } diff --git a/crates/argus-ext/src/ty/impl.rs b/crates/argus-ext/src/ty/impl.rs index 3689b89..b85e814 100644 --- a/crates/argus-ext/src/ty/impl.rs +++ b/crates/argus-ext/src/ty/impl.rs @@ -155,11 +155,11 @@ impl<'tcx> TyCtxtExt<'tcx> for TyCtxt<'tcx> { fn does_trait_ref_occur_in( &self, - needle: ty::TraitRef<'tcx>, + needle: ty::PolyTraitRef<'tcx>, haystack: ty::Predicate<'tcx>, ) -> bool { struct TraitRefVisitor<'tcx> { - tr: ty::TraitRef<'tcx>, + tr: ty::PolyTraitRef<'tcx>, tcx: TyCtxt<'tcx>, found: bool, } @@ -171,7 +171,7 @@ impl<'tcx> TyCtxtExt<'tcx> for TyCtxt<'tcx> { def_id: DefId, ) -> bool { let my_ty = self.tr.self_ty(); - let my_id = self.tr.def_id; + let my_id = self.tr.def_id(); if args.is_empty() { return false; @@ -179,7 +179,8 @@ impl<'tcx> TyCtxtExt<'tcx> for TyCtxt<'tcx> { // FIXME: is it always the first type in the args list? let proj_ty = args.type_at(0); - proj_ty == my_ty && self.tcx.is_descendant_of(def_id, my_id) + proj_ty == my_ty.skip_binder() + && self.tcx.is_descendant_of(def_id, my_id) } } @@ -200,13 +201,43 @@ impl<'tcx> TyCtxtExt<'tcx> for TyCtxt<'tcx> { log::debug!("* [{predicate:#?}]"); if let ty::PredicateKind::Clause(ty::ClauseKind::Projection( - ty::ProjectionPredicate { + pp @ ty::ProjectionPredicate { projection_term, .. }, )) = predicate.kind().skip_binder() { - self.found |= self + use rustc_infer::traits::util::supertraits; + + // Check whether the `TraitRef`, or any implied supertrait + // appear in the projection. This can happen for example if we have + // a trait predicate `F: Fn(i32) -> i32`, the projection of the `Output` + // would be `::Output == i32`. + + let simple_check = self .occurs_in_projection(projection_term.args, projection_term.def_id); + let deep_check = || { + let prj_ply_trait_ref = predicate.kind().rebind(pp); + let poly_supertrait_ref = + prj_ply_trait_ref.required_poly_trait_ref(self.tcx); + // Check whether `poly_supertrait_ref` is a supertrait of `self.tr`. + // HACK FIXME: this is too simplistic, it's unsound to check + // *just* that the `self_ty`s are equivalent and that the `def_id` is + // a super trait... + log::debug!( + "deep_check:\n {:?}\n to super\n {:?}", + self.tr, + poly_supertrait_ref + ); + for super_ptr in supertraits(self.tcx, self.tr) { + log::debug!("* against {super_ptr:?}"); + if super_ptr == poly_supertrait_ref { + return true; + } + } + false + }; + + self.found |= simple_check || deep_check(); } predicate.super_visit_with(self); diff --git a/crates/argus-ser/Cargo.toml b/crates/argus-ser/Cargo.toml index e26585e..ef1ccd7 100644 --- a/crates/argus-ser/Cargo.toml +++ b/crates/argus-ser/Cargo.toml @@ -25,8 +25,10 @@ rustc_utils.workspace = true serde.workspace = true serde_json.workspace = true smallvec = "1.11.2" +itertools = "0.12.0" ts-rs = { version = "7.1.1", features = ["indexmap-impl"], optional = true } index_vec = { version = "0.1.3", features = ["serde"] } +argus-ext = { version = "0.1.6", path = "../argus-ext" } [dev-dependencies] argus-ser = { path = ".", features = ["testing"] } diff --git a/crates/argus-ser/src/custom.rs b/crates/argus-ser/src/custom.rs index 2c5c160..0eb8ddc 100644 --- a/crates/argus-ser/src/custom.rs +++ b/crates/argus-ser/src/custom.rs @@ -1,7 +1,10 @@ //! Extensions to the type system for easier consumption. +use argus_ext::ty::TyCtxtExt; +use itertools::Itertools; use rustc_data_structures::fx::FxIndexMap; use rustc_hir::def_id::DefId; -use rustc_middle::ty; +use rustc_macros::TypeVisitable; +use rustc_middle::ty::{self, Upcast}; use serde::Serialize; #[cfg(feature = "testing")] use ts_rs::TS; @@ -31,17 +34,72 @@ pub struct ImplHeader<'tcx> { pub tys_without_default_bounds: Vec>, } -#[derive(Serialize)] +#[derive(Debug, Clone, TypeVisitable, Serialize)] #[cfg_attr(feature = "testing", derive(TS))] #[cfg_attr(feature = "testing", ts(export))] pub struct GroupedClauses<'tcx> { - pub grouped: Vec>, + #[serde(with = "Slice__PolyClauseWithBoundsDef")] + #[cfg_attr(feature = "testing", ts(type = "PolyClauseWithBounds[]"))] + pub grouped: Vec>, + #[serde(with = "myty::Slice__ClauseDef")] #[cfg_attr(feature = "testing", ts(type = "Clause[]"))] pub other: Vec>, } +pub struct Slice__PolyClauseWithBoundsDef; +impl Slice__PolyClauseWithBoundsDef { + pub fn serialize( + value: &[PolyClauseWithBounds], + s: S, + ) -> Result + where + S: serde::Serializer, + { + #[derive(Serialize)] + struct Wrapper<'a, 'tcx: 'a>( + #[serde(with = "Binder__ClauseWithBounds")] + &'a PolyClauseWithBounds<'tcx>, + ); + + crate::serialize_custom_seq! { Wrapper, s, value } + } +} + #[derive(Serialize)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "testing", derive(TS))] +#[cfg_attr(feature = "testing", ts(export, rename = "PolyClauseWithBounds"))] +pub struct Binder__ClauseWithBounds<'tcx> { + value: ClauseWithBounds<'tcx>, + + #[serde(with = "myty::Slice__BoundVariableKindDef")] + #[cfg_attr(feature = "testing", ts(type = "BoundVariableKind[]"))] + bound_vars: &'tcx ty::List, +} + +impl<'tcx> Binder__ClauseWithBounds<'tcx> { + pub fn new(value: &PolyClauseWithBounds<'tcx>) -> Self { + Self { + bound_vars: value.bound_vars(), + value: value.clone().skip_binder(), + } + } + + pub fn serialize( + value: &PolyClauseWithBounds<'tcx>, + s: S, + ) -> Result + where + S: serde::Serializer, + { + Self::new(value).serialize(s) + } +} + +type PolyClauseWithBounds<'tcx> = ty::Binder<'tcx, ClauseWithBounds<'tcx>>; + +#[derive(Debug, Clone, TypeVisitable, Serialize)] #[cfg_attr(feature = "testing", derive(TS))] #[cfg_attr(feature = "testing", ts(export))] pub struct ClauseWithBounds<'tcx> { @@ -51,14 +109,24 @@ pub struct ClauseWithBounds<'tcx> { pub bounds: Vec>, } -#[derive(Serialize)] +#[derive(Debug, Copy, Clone, TypeVisitable, Serialize)] #[cfg_attr(feature = "testing", derive(TS))] #[cfg_attr(feature = "testing", ts(export))] pub enum ClauseBound<'tcx> { Trait( myty::Polarity, + #[serde(with = "myty::TraitRefPrintOnlyTraitPathDef")] #[cfg_attr(feature = "testing", ts(type = "TraitRefPrintOnlyTraitPath"))] - crate::TraitRefPrintOnlyTraitPathDef<'tcx>, + ty::TraitRef<'tcx>, + ), + FnTrait( + myty::Polarity, + #[serde(with = "myty::TraitRefPrintOnlyTraitPathDef")] + #[cfg_attr(feature = "testing", ts(type = "TraitRefPrintOnlyTraitPath"))] + ty::TraitRef<'tcx>, + #[serde(with = "myty::TyDef")] + #[cfg_attr(feature = "testing", ts(type = "Ty"))] + ty::Ty<'tcx>, ), Region( #[serde(with = "myty::RegionDef")] @@ -68,27 +136,44 @@ pub enum ClauseBound<'tcx> { } pub fn group_predicates_by_ty<'tcx>( + tcx: ty::TyCtxt<'tcx>, predicates: impl IntoIterator>, ) -> GroupedClauses<'tcx> { // ARGUS: ADDITION: group predicates together based on `self_ty`. - let mut grouped: FxIndexMap<_, Vec<_>> = FxIndexMap::default(); + let mut grouped = FxIndexMap::<_, Vec<_>>::default(); let mut other = vec![]; + + // TODO: this only looks at the output of an `FnOnce`, we probably also need + // to handle `AsyncFnOnce` and consider doing the same for the output of + // a `Future`. A further goal could be sugaring all associated type bounds + // back into the signature but that would require more work (not sure how much). + let fn_trait_output = tcx.lang_items().fn_once_output(); + let mut fn_output_projections = vec![]; + for p in predicates { - // TODO: all this binder skipping is a HACK. if let Some(poly_trait_pred) = p.as_trait_clause() { let ty = poly_trait_pred.self_ty().skip_binder(); let trait_ref = poly_trait_pred.map_bound(|tr| tr.trait_ref).skip_binder(); - let bound = ClauseBound::Trait( - poly_trait_pred.polarity().into(), - crate::TraitRefPrintOnlyTraitPathDef(trait_ref), - ); - grouped.entry(ty).or_default().push(bound); + let bound = + ClauseBound::Trait(poly_trait_pred.polarity().into(), trait_ref); + grouped + .entry(ty) + .or_default() + .push(poly_trait_pred.rebind(bound)); } else if let Some(poly_ty_outl) = p.as_type_outlives_clause() { let ty = poly_ty_outl.map_bound(|t| t.0).skip_binder(); let r = poly_ty_outl.map_bound(|t| t.1).skip_binder(); let bound = ClauseBound::Region(r); - grouped.entry(ty).or_default().push(bound); + grouped + .entry(ty) + .or_default() + .push(poly_ty_outl.rebind(bound)); + } else if let Some(poly_projection) = p.as_projection_clause() + && let Some(output_defid) = fn_trait_output + && poly_projection.projection_def_id() == output_defid + { + fn_output_projections.push(poly_projection); } else { other.push(p); } @@ -96,9 +181,61 @@ pub fn group_predicates_by_ty<'tcx>( let grouped = grouped .into_iter() - .map(|(ty, bounds)| ClauseWithBounds { ty, bounds }) + .map(|(ty, bounds)| { + // NOTE: we have to call unique to make a `List` + let all_bound_vars = bounds.iter().flat_map(|b| b.bound_vars()).unique(); + let bound_vars = tcx.mk_bound_variable_kinds_from_iter(all_bound_vars); + let unbounds = bounds + .into_iter() + .map(|bclause| { + let clause = bclause.skip_binder(); + if let ClauseBound::Trait(p, tref) = clause + && tcx.is_fn_trait(tref.def_id) + && let poly_tr = bclause.rebind(tref) + && let matching_projections = fn_output_projections + .extract_if(move |p| { + tcx.does_trait_ref_occur_in( + poly_tr, + p.map_bound(|p| { + ty::PredicateKind::Clause(ty::ClauseKind::Projection(p)) + }) + .upcast(tcx), + ) + }) + .unique() + .collect::>() + && !matching_projections.is_empty() + { + log::debug!( + "Matching projections for {bclause:?} {matching_projections:#?}" + ); + let ret_ty = matching_projections[0] + .term() + .skip_binder() + .ty() + .expect("FnOnce::Output Ty"); + debug_assert!(matching_projections.len() == 1); + ClauseBound::FnTrait(p, tref, ret_ty) + } else { + clause + } + }) + .collect(); + ty::Binder::bind_with_vars( + ClauseWithBounds { + ty, + bounds: unbounds, + }, + bound_vars, + ) + }) .collect::>(); + assert!( + fn_output_projections.is_empty(), + "Remaining output projections {fn_output_projections:#?}" + ); + GroupedClauses { grouped, other } } @@ -148,7 +285,7 @@ pub fn get_opt_impl_header( log::debug!("pretty predicates for impl header {:#?}", pretty_predicates); // Argus addition - let grouped_clauses = group_predicates_by_ty(pretty_predicates); + let grouped_clauses = group_predicates_by_ty(tcx, pretty_predicates); let tys_without_default_bounds = types_without_default_bounds.into_iter().collect::>(); @@ -158,7 +295,6 @@ pub fn get_opt_impl_header( name, self_ty, predicates: grouped_clauses, - // predicates: pretty_predicates, tys_without_default_bounds, }) } diff --git a/crates/argus-ser/src/lib.rs b/crates/argus-ser/src/lib.rs index 0ff13b0..1978686 100644 --- a/crates/argus-ser/src/lib.rs +++ b/crates/argus-ser/src/lib.rs @@ -28,13 +28,16 @@ let_chains, if_let_guard, decl_macro, + extract_if, associated_type_defaults )] #![allow(non_camel_case_types, non_snake_case)] extern crate rustc_apfloat; +extern crate rustc_ast_ir; extern crate rustc_data_structures; extern crate rustc_hir; extern crate rustc_infer; +extern crate rustc_macros; extern crate rustc_middle; extern crate rustc_span; extern crate rustc_target; @@ -98,19 +101,18 @@ impl<'tcx> InferCtxtSerializeExt for InferCtxt<'tcx> { } } +#[macro_export] macro_rules! serialize_custom_seq { ($wrap:ident, $serializer:expr, $value:expr) => {{ use serde::ser::SerializeSeq; let mut seq = $serializer.serialize_seq(Some($value.len()))?; - for e in $value.iter() { + for e in $value.into_iter() { seq.serialize_element(&$wrap(e))?; } seq.end() }}; } -pub(crate) use serialize_custom_seq; - // ---------------------------------------- // Parameters diff --git a/crates/argus-ser/src/path/mod.rs b/crates/argus-ser/src/path/mod.rs index 4be39de..94f0bbc 100644 --- a/crates/argus-ser/src/path/mod.rs +++ b/crates/argus-ser/src/path/mod.rs @@ -102,11 +102,16 @@ enum PathSegment<'tcx> { GenericDelimiters { inner: Vec>, }, // < ... > - CommaSeparated { - #[cfg_attr(feature = "testing", ts(type = "any[]"))] - entries: Vec, - kind: CommaSeparatedKind, - }, // ..., ..., ... + GenericArgumentList { + #[serde(with = "serial_ty::Slice__GenericArgDef")] + #[cfg_attr(feature = "testing", ts(type = "GenericArg[]"))] + entries: Vec>, + }, + // CommaSeparated { + // #[cfg_attr(feature = "testing", ts(type = "any[]"))] + // entries: Vec, + // kind: CommaSeparatedKind, + // }, // ..., ..., ... Impl { #[cfg_attr(feature = "testing", ts(type = "DefinedPath"))] #[serde(skip_serializing_if = "Option::is_none")] diff --git a/crates/argus-ser/src/path/pretty.rs b/crates/argus-ser/src/path/pretty.rs index 4607099..f65c0d2 100644 --- a/crates/argus-ser/src/path/pretty.rs +++ b/crates/argus-ser/src/path/pretty.rs @@ -16,12 +16,9 @@ use rustc_hir::{ use rustc_middle::ty::{self, *}; use rustc_span::symbol::Ident; use rustc_utils::source_map::range::CharRange; -use serde::Serialize; -use super::{ - super::ty::{GenericArgDef, TraitRefPrintOnlyTraitPathDef}, - *, -}; +use super::*; +use crate::ty::TraitRefPrintOnlyTraitPathDef; impl<'tcx> PathBuilder<'tcx> { pub fn print_def_path( @@ -148,13 +145,16 @@ impl<'tcx> PathBuilder<'tcx> { // CHANGE: write!(self, "::")?; self.segments.push(PathSegment::Colons); } - self.generic_delimiters(|cx| { - #[derive(Serialize)] - struct Wrapper<'a, 'tcx: 'a>( - #[serde(with = "GenericArgDef")] &'a GenericArg<'tcx>, - ); - cx.comma_sep(args.iter().map(Wrapper), CommaSeparatedKind::GenericArg); - }); + self.segments.push(PathSegment::GenericArgumentList { + entries: args.iter().copied().collect(), + }) + // self.generic_delimiters(|cx| { + // #[derive(Serialize)] + // struct Wrapper<'a, 'tcx: 'a>( + // #[serde(with = "GenericArgDef")] &'a GenericArg<'tcx>, + // ); + // cx.comma_sep(args.iter().map(Wrapper), CommaSeparatedKind::GenericArg); + // }); } } @@ -220,33 +220,35 @@ impl<'tcx> PathBuilder<'tcx> { // CHANGE: write!(self, ">")?; } - /// Prints comma-separated elements. - fn comma_sep( - &mut self, - elems: impl Iterator, - kind: CommaSeparatedKind, - ) where - T: Serialize, - // T: Print<'tcx, Self>, - { - // CHANGE: - // if let Some(first) = elems.next() { - // // first.print(self)?; - // for elem in elems { - // // self.write_str(", ")?; - // // elem.print(self)?; - // } - // } - self.segments.push(PathSegment::CommaSeparated { - entries: elems - .map(|e| { - serde_json::to_value(e) - .expect("failed to serialize comma separated value") - }) - .collect::>(), - kind, - }); - } + // NOTE: this was only used for generic argument lists, for now it's been removed, + // in the future if we need a generic way of including commas you can bring it back. + // /// Prints comma-separated elements. + // fn comma_sep( + // &mut self, + // elems: impl Iterator, + // kind: CommaSeparatedKind, + // ) where + // T: Serialize, + // // T: Print<'tcx, Self>, + // { + // // CHANGE: + // // if let Some(first) = elems.next() { + // // // first.print(self)?; + // // for elem in elems { + // // // self.write_str(", ")?; + // // // elem.print(self)?; + // // } + // // } + // self.segments.push(PathSegment::CommaSeparated { + // entries: elems + // .map(|e| { + // serde_json::to_value(e) + // .expect("failed to serialize comma separated value") + // }) + // .collect::>(), + // kind, + // }); + // } pub fn path_append_impl( &mut self, diff --git a/crates/argus-ser/src/ty.rs b/crates/argus-ser/src/ty.rs index 2de849a..056a744 100644 --- a/crates/argus-ser/src/ty.rs +++ b/crates/argus-ser/src/ty.rs @@ -1,6 +1,7 @@ use rustc_data_structures::fx::FxIndexMap; use rustc_hir::{self as hir, def::DefKind, def_id::DefId, LangItem, Safety}; use rustc_infer::traits::{ObligationCause, PredicateObligation}; +use rustc_macros::TypeVisitable; use rustc_middle::{traits::util::supertraits_for_pretty_printing, ty}; use rustc_span::symbol::{kw, Symbol}; use rustc_target::spec::abi::Abi; @@ -591,6 +592,7 @@ impl PlaceholderTyDef { // Function signature definitions #[derive(Serialize)] +#[serde(rename_all = "camelCase")] #[cfg_attr(feature = "testing", derive(TS))] #[cfg_attr(feature = "testing", ts(export, rename = "PolyFnSig"))] pub struct Binder__FnSigDef<'tcx> { @@ -788,9 +790,9 @@ impl BoundTyDef { } } -// -------------------------------------------- -// -------------------------------------------- -// TODO: the DefId's here need to be dealt with +// ================================================== +// VV TODO: the DefId's here need to be dealt with VV +// ================================================== #[derive(Serialize)] #[serde(remote = "ty::BoundVariableKind")] @@ -799,7 +801,7 @@ impl BoundTyDef { pub enum BoundVariableKindDef { Ty( #[serde(with = "BoundTyKindDef")] - #[cfg_attr(feature = "testing", ts(type = "any"))] + #[cfg_attr(feature = "testing", ts(type = "BoundTyKind"))] ty::BoundTyKind, ), Region( @@ -812,7 +814,7 @@ pub enum BoundVariableKindDef { pub struct Slice__BoundVariableKindDef; impl Slice__BoundVariableKindDef { - fn serialize( + pub fn serialize( value: &[ty::BoundVariableKind], s: S, ) -> Result @@ -844,16 +846,21 @@ pub enum BoundRegionKindDef { #[derive(Serialize)] #[serde(remote = "ty::BoundTyKind")] -// #[cfg_attr(feature = "testing", derive(TS))] -// #[cfg_attr(feature = "testing", ts(export, rename = "BoundTyKind"))] +#[cfg_attr(feature = "testing", derive(TS))] +#[cfg_attr(feature = "testing", ts(export, rename = "BoundTyKind"))] pub enum BoundTyKindDef { Anon, - Param(#[serde(skip)] DefId, #[serde(skip)] Symbol), + Param( + #[serde(skip)] DefId, + #[serde(with = "SymbolDef")] + #[cfg_attr(feature = "testing", ts(type = "Symbol"))] + Symbol, + ), } -// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +// ============================================================ +// ^^^^^^^^^ Above comment applies within this range ^^^^^^^^^^ +// ============================================================ #[derive(Serialize)] #[serde(remote = "ty::IntTy")] @@ -1167,7 +1174,11 @@ pub struct ParamEnvDef<'tcx>(crate::custom::GroupedClauses<'tcx>); impl<'tcx> ParamEnvDef<'tcx> { pub fn new(value: &ty::ParamEnv<'tcx>) -> Self { - Self(crate::custom::group_predicates_by_ty(value.caller_bounds())) + let tcx = InferCtxt::access(|cx| cx.tcx); + Self(crate::custom::group_predicates_by_ty( + tcx, + value.caller_bounds(), + )) } pub fn serialize( @@ -1203,6 +1214,7 @@ impl<'tcx> PredicateDef<'tcx> { } #[derive(Serialize)] +#[serde(rename_all = "camelCase")] #[cfg_attr(feature = "testing", derive(TS))] #[cfg_attr(feature = "testing", ts(export, rename = "PolyPredicateKind"))] pub struct Binder__PredicateKind<'tcx> { @@ -1551,7 +1563,7 @@ impl<'tcx> TraitRefPrintOnlyTraitPathDef<'tcx> { } } -#[derive(Serialize)] +#[derive(Debug, Copy, Clone, TypeVisitable, Serialize)] #[cfg_attr(feature = "testing", derive(TS))] #[cfg_attr(feature = "testing", ts(export))] pub enum Polarity { diff --git a/ide/packages/common/src/func.ts b/ide/packages/common/src/func.ts index 49ffffe..3148455 100644 --- a/ide/packages/common/src/func.ts +++ b/ide/packages/common/src/func.ts @@ -1,9 +1,14 @@ import _ from "lodash"; import type { + BoundRegionKind, + BoundTyKind, + BoundVariableKind, CharRange, + GenericArg, ObligationHash, Predicate, + Region, Ty, TyVal } from "./bindings"; @@ -125,7 +130,7 @@ export function takeRightUntil(arr: T[], pred: (t: T) => boolean) { export type Unit = { Tuple: Ty[] }; -export function tyIsUnit(o: TyVal): o is Unit { +export function isUnitTy(o: TyVal): o is Unit { return isObject(o) && "Tuple" in o && o.Tuple.length === 0; } @@ -151,3 +156,27 @@ export function fnInputsAndOutput(args: T[]): [T[], T] { let output = _.last(args)!; return [inputs, output]; } + +export const isNamedRegion = (r: Region) => r.type === "Named"; + +export function isNamedGenericArg(ga: GenericArg) { + return ("Lifetime" in ga) ? isNamedRegion(ga.Lifetime) : true; +} + +export const isNamedBoundRegion = (br: BoundRegionKind) => + isObject(br) && "BrNamed" in br && br.BrNamed[0] !== "'_"; + +export const isNamedBoundTy = (bt: BoundTyKind) => + isObject(bt) && "Param" in bt; + +export function isNamedBoundVariable(bv: BoundVariableKind) { + if (isObject(bv)) { + if ("Region" in bv) { + return isNamedBoundRegion(bv.Region); + } else if ("Ty" in bv) { + return isNamedBoundTy(bv.Ty); + } + } + + return false; +} diff --git a/ide/packages/print/src/private/argus.tsx b/ide/packages/print/src/private/argus.tsx index 9ccff13..249a245 100644 --- a/ide/packages/print/src/private/argus.tsx +++ b/ide/packages/print/src/private/argus.tsx @@ -5,22 +5,29 @@ import type { ImplHeader, Ty } from "@argus/common/bindings"; -import { anyElems } from "@argus/common/func"; +import { anyElems, isUnitTy } from "@argus/common/func"; import _ from "lodash"; -import React from "react"; +import React, { useContext } from "react"; import { Toggle } from "../Toggle"; -import { AllowProjectionSubst } from "../context"; +import { AllowProjectionSubst, TyCtxt } from "../context"; import { PrintDefPath } from "./path"; import { PrintClause } from "./predicate"; -import { Angled, CommaSeparated, Kw, PlusSeparated } from "./syntax"; -import { PrintGenericArg, PrintPolarity, PrintRegion, PrintTy } from "./ty"; +import { Angled, CommaSeparated, Kw, PlusSeparated, nbsp } from "./syntax"; +import { + PrintBinder, + PrintGenericArg, + PrintPolarity, + PrintRegion, + PrintTy, + PrintTyKind +} from "./ty"; // NOTE: it looks ugly, but we need to disable projection substitution for all parts // of the impl blocks. export const PrintImplHeader = ({ o }: { o: ImplHeader }) => { console.debug("Printing ImplHeader", o); - const genArgs = _.map(o.args, arg => () => ( + const genArgs = _.map(o.args, arg => ( @@ -40,7 +47,8 @@ export const PrintImplHeader = ({ o }: { o: ImplHeader }) => { return ( impl - {argsWAngle} for{" "} + {argsWAngle} for + {nbsp} { export const PrintGroupedClauses = ({ o }: { o: GroupedClauses }) => { console.debug("Printing GroupedClauses", o); + const Inner = ({ value }: { value: ClauseWithBounds }) => ( + + ); const groupedClauses = _.map(o.grouped, (group, idx) => (
- +
)); const noGroupedClauses = _.map(o.other, (clause, idx) => ( @@ -95,17 +106,17 @@ export const PrintWhereClause = ({ return ( <> {" "} - where + where + {nbsp} + ); }; const PrintClauseWithBounds = ({ o }: { o: ClauseWithBounds }) => { const [traits, lifetimes] = _.partition(o.bounds, bound => "Trait" in bound); - const traitBounds = _.map(traits, bound => () => ( - - )); - const lifetimeBounds = _.map(lifetimes, bound => () => ( + const traitBounds = _.map(traits, bound => ); + const lifetimeBounds = _.map(lifetimes, bound => ( )); const boundComponents = _.concat(traitBounds, lifetimeBounds); @@ -118,7 +129,26 @@ const PrintClauseWithBounds = ({ o }: { o: ClauseWithBounds }) => { }; const PrintClauseBound = ({ o }: { o: ClauseBound }) => { - if ("Trait" in o) { + const tyCtxt = useContext(TyCtxt)!; + if ("FnTrait" in o) { + const [polarity, path, res] = o.FnTrait; + const tyKind = tyCtxt.interner[res]; + const arrow = isUnitTy(tyKind) ? null : ( + <> + {nbsp} + {"->"} + {nbsp} + + + ); + return ( + <> + + + {arrow} + + ); + } else if ("Trait" in o) { const [polarity, path] = o.Trait; return ( <> @@ -126,9 +156,9 @@ const PrintClauseBound = ({ o }: { o: ClauseBound }) => { ); - } - if ("Region" in o) { + } else if ("Region" in o) { return ; } + throw new Error("Unknown clause bound", o); }; diff --git a/ide/packages/print/src/private/path.tsx b/ide/packages/print/src/private/path.tsx index 0c85c1e..26456ef 100644 --- a/ide/packages/print/src/private/path.tsx +++ b/ide/packages/print/src/private/path.tsx @@ -1,10 +1,10 @@ import type { DefinedPath, PathSegment } from "@argus/common/bindings"; -import { takeRightUntil } from "@argus/common/func"; +import { isNamedGenericArg, takeRightUntil } from "@argus/common/func"; import _ from "lodash"; import React, { useContext } from "react"; import { Toggle } from "../Toggle"; import { AllowPathTrim, AllowToggle, DefPathRender, TyCtxt } from "../context"; -import { Angled, CommaSeparated, Kw } from "./syntax"; +import { Angled, CommaSeparated, Kw, nbsp } from "./syntax"; import { PrintGenericArg, PrintTy } from "./ty"; // Special case the printing for associated types. Things that look like @@ -151,56 +151,72 @@ export const PrintPathSegment = ({ o }: { o: PathSegment }) => { // TODO: we should actually print something here (or send the file snippet). return impl@{o.range.toString()}; } + // General case of wrapping segments in angle brackets. case "GenericDelimiters": { // We don't want empty <> on the end of types if (o.inner.length === 0) { return null; } - // Use a metric of "type size" rather than inner lenght. - const useToggle = useContext(AllowToggle) && o.inner.length > 3; + return ( - // TODO: do we want to allow nested toggles? - - - {useToggle ? ( - } - /> - ) : ( - - )} - - + 3} + Elem={() => } + /> ); } - case "CommaSeparated": { - const Mapper = - o.kind.type === "GenericArg" - ? PrintGenericArg - : ({ o }: { o: any }) => { - throw new Error("Unknown comma separated kind", o); - }; - const components = _.map(o.entries, entry => () => ); - return ; + // Angle brackets used *specifically* for a list of generic arguments. + case "GenericArgumentList": { + const namedArgs = _.filter(o.entries, isNamedGenericArg); + if (namedArgs.length === 0) { + return null; + } + + const components = _.map(namedArgs, (arg, i) => ( + + )); + return ( + 3} + Elem={() => } + /> + ); } default: throw new Error("Unknown path segment", o); } }; +// NOTE: used as a helper for the `GenericDelimiters` and `GenericArgumentList` segments. +const PrintInToggleableEnvironment = ({ + bypassToggle, + Elem +}: { bypassToggle: boolean; Elem: React.FC }) => { + // Use a metric of "type size" rather than inner lenght. + const useToggle = useContext(AllowToggle) && bypassToggle; + return ( + // TODO: do we want to allow nested toggles? + + {useToggle ? } /> : } + + ); +}; + // export const PrintImplFor = ({ path, ty }: { path?: DefinedPath; ty: any }) => { const p = path === undefined ? null : ( <> - for{" "} + for + {nbsp} ); return ( <> - impl {p} + impl + {nbsp} + {p} ); @@ -211,8 +227,8 @@ export const PrintImplAs = ({ path, ty }: { path?: DefinedPath; ty: any }) => { const p = path === undefined ? null : ( <> - {" "} - as + {nbsp} + as ); diff --git a/ide/packages/print/src/private/predicate.tsx b/ide/packages/print/src/private/predicate.tsx index 564e6e5..09778dc 100644 --- a/ide/packages/print/src/private/predicate.tsx +++ b/ide/packages/print/src/private/predicate.tsx @@ -62,8 +62,10 @@ export const PrintParamEnv = ({ o }: { o: ParamEnv }) => { }; export const PrintBinderPredicateKind = ({ o }: { o: PolyPredicateKind }) => { - const inner = (o: PredicateKind) => ; - return ; + const Inner = ({ value }: { value: PredicateKind }) => ( + + ); + return ; }; export const PrintPredicateKind = ({ o }: { o: PredicateKind }) => { @@ -134,8 +136,10 @@ export const PrintAliasRelationDirection = ({ }; export const PrintClause = ({ o }: { o: Clause }) => { - const inner = (o: ClauseKind) => ; - return ; + const Inner = ({ value }: { value: ClauseKind }) => ( + + ); + return ; }; export const PrintClauseKind = ({ o }: { o: ClauseKind }) => { diff --git a/ide/packages/print/src/private/syntax.css b/ide/packages/print/src/private/syntax.css index 1bfa01e..0c66fb8 100644 --- a/ide/packages/print/src/private/syntax.css +++ b/ide/packages/print/src/private/syntax.css @@ -4,4 +4,50 @@ span.kw { span.placeholder { color: var(--vscode-input-placeholderForeground); +} + +span.stx-wrapper.angles::before { + content: "<"; +} +span.stx-wrapper.angles::after { + content: ">"; +} + +span.stx-wrapper.parens::before { + content: "("; +} +span.stx-wrapper.parens::after { + content: ")"; +} + + +span.stx-wrapper.dbracket::before { + content: "{{"; +} +span.stx-wrapper.dbracket::after { + content: "}}"; +} + + +span.stx-wrapper.bracket::before { + content: "{"; +} +span.stx-wrapper.bracket::after { + content: "}"; +} + +span.stx-wrapper.sqbracket::before { + content: "["; +} +span.stx-wrapper.sqbracket::after { + content: "]"; +} + +.interspersed-list > span.comma:not(:last-child, :empty):after { + content: ',\00a0'; +} + +.interspersed-list > span.plus:not(:first-child, :empty):before { + /* Use NBSP for second space, to keep this aligned with the subsequent element */ + content: '\00a0+\00a0'; } \ No newline at end of file diff --git a/ide/packages/print/src/private/syntax.tsx b/ide/packages/print/src/private/syntax.tsx index 2387888..5a15be8 100644 --- a/ide/packages/print/src/private/syntax.tsx +++ b/ide/packages/print/src/private/syntax.tsx @@ -2,22 +2,59 @@ import _ from "lodash"; import React from "react"; import "./syntax.css"; +import classNames from "classnames"; -// A "Discretionary Space", hopefully this allows the layout to break along -// these elements rather than in the middle of text or random spaces. -export const Dsp = ({ children }: React.PropsWithChildren) => ( - {children} -); +export const nbsp = "\u00A0"; + +// A "Discretionary Space", the `inline-block` style helps format around these elements +// and breaks less between them and in random spaces. +export const Dsp = ( + props: React.PropsWithChildren & React.HTMLAttributes +) => { + const kids = props.children; + const htmlAttrs: React.HTMLAttributes = { + ...props, + children: undefined + }; + return ( + + {kids} + + ); +}; +/** + * Highlight the children as placeholders, this means they aren't concrete types. + * + * For Argus, this usually means changing the foreground to something softer. + */ export const Placeholder = ({ children }: React.PropsWithChildren) => ( {children} ); +/** + * Highlight the children as Rust keywords + */ export const Kw = ({ children }: React.PropsWithChildren) => ( {children} ); -const makeWrapper = +/** + * Create a wrapper around the children using a `stx-wrapper` class and the + * additional class `c`. This makes a wrapper that breakes around the wrapped + * elements. + */ +const makeCSSWrapper = + (c: string) => + ({ children }: React.PropsWithChildren) => ( + {children} + ); + +/** + * Create a wrapper that breaks around the children, but allows the `LHS` and `RHS` + * wrapping elements to split from their children. + */ +const makeBreakingWrapper = (lhs: string, rhs: string) => ({ children }: React.PropsWithChildren) => ( <> @@ -27,34 +64,47 @@ const makeWrapper = ); -export const Angled = makeWrapper("<", ">"); -export const DBraced = makeWrapper("{{", "}}"); -export const CBraced = makeWrapper("{", "}"); -export const Parenthesized = makeWrapper("(", ")"); -export const SqBraced = makeWrapper("[", "]"); +// We want content to break around parens and angle brackets. +// E.g., `fn foo(a: A, b: B) -> B` could be formatted as: +// ``` +// fn foo< +// A, B +// >( +// a: A, +// b: B +// ) -> B +// ``` +export const Angled = makeBreakingWrapper("<", ">"); +export const Parenthesized = makeBreakingWrapper("(", ")"); -export const CommaSeparated = ({ components }: { components: React.FC[] }) => ( - +export const DBraced = makeCSSWrapper("dbracket"); +export const CBraced = makeCSSWrapper("bracket"); +export const SqBraced = makeCSSWrapper("sqbracket"); + +export const CommaSeparated = ({ + components +}: { components: React.ReactElement[] }) => ( + ); -export const PlusSeparated = ({ components }: { components: React.FC[] }) => ( - +export const PlusSeparated = ({ + components +}: { components: React.ReactElement[] }) => ( + ); const Interspersed = ({ components, sep }: { - components: React.FC[]; + components: React.ReactElement[]; sep: string; -}) => - _.map(components, (C, i) => ( - // The inline-block span should help the layout to break on the elements - // and not in them. Still undecided if this actually does anything. - - {i === 0 ? "" : sep} - - - - - )); +}) => ( + + {_.map(components, (C, i) => ( + + {C} + + ))} + +); diff --git a/ide/packages/print/src/private/term.tsx b/ide/packages/print/src/private/term.tsx index 3b4b499..1f531b3 100644 --- a/ide/packages/print/src/private/term.tsx +++ b/ide/packages/print/src/private/term.tsx @@ -63,7 +63,7 @@ export const PrintExpr = ({ o }: { o: ExprDef }) => { } if ("FunctionCall" in o) { const [callable, args] = o.FunctionCall; - const argEs = _.map(args, arg => () => ); + const argEs = _.map(args, arg => ); return ( <> ( @@ -212,7 +212,7 @@ export const PrintValueTree = ({ o }: { o: ValTree }) => { }; const PrintAggregateArray = ({ fields }: { fields: Const[] }) => { - const components = _.map(fields, field => () => ); + const components = _.map(fields, field => ); return ( @@ -221,7 +221,7 @@ const PrintAggregateArray = ({ fields }: { fields: Const[] }) => { }; const PrintAggregateTuple = ({ fields }: { fields: Const[] }) => { - const components = _.map(fields, field => () => ); + const components = _.map(fields, field => ); const trailingComma = fields.length === 1 ? "," : null; return ( @@ -248,7 +248,7 @@ const PrintAggregateAdt = ({ switch (kind.type) { case "Fn": { const head = ; - const components = _.map(fields, field => () => ); + const components = _.map(fields, field => ); return ( <> {head} @@ -263,15 +263,11 @@ const PrintAggregateAdt = ({ return null; } case "Misc": { - const components = _.map( - _.zip(kind.names, fields), - ([name, field]) => - () => ( - <> - : - - ) - ); + const components = _.map(_.zip(kind.names, fields), ([name, field]) => ( + <> + : + + )); return ( diff --git a/ide/packages/print/src/private/ty.tsx b/ide/packages/print/src/private/ty.tsx index bdbf2b7..a9db7d8 100644 --- a/ide/packages/print/src/private/ty.tsx +++ b/ide/packages/print/src/private/ty.tsx @@ -4,8 +4,11 @@ import type { AliasTy, AliasTyKind, AssocItem, + BoundRegionKind, BoundTy, + BoundTyKind, BoundVariable, + BoundVariableKind, CoroutineClosureTyKind, CoroutineTyKind, CoroutineWitnessTyKind, @@ -33,7 +36,12 @@ import type { TypeAndMut, UintTy } from "@argus/common/bindings"; -import { anyElems, fnInputsAndOutput, tyIsUnit } from "@argus/common/func"; +import { + anyElems, + fnInputsAndOutput, + isNamedBoundVariable, + isUnitTy +} from "@argus/common/func"; import {} from "@floating-ui/react"; import _, { isObject } from "lodash"; import React, { useContext } from "react"; @@ -50,18 +58,45 @@ import { Parenthesized, Placeholder, PlusSeparated, - SqBraced + SqBraced, + nbsp } from "./syntax"; import { PrintTerm } from "./term"; -export const PrintBinder = ({ +interface Binding { + value: T; + boundVars: BoundVariableKind[]; +} + +export const PrintBinder = ({ binder, - innerF + Child }: { - binder: any; - innerF: any; + binder: Binding; + // FIXME: shouldn't this just be `React.FC`?? Doesn't typecheck though... + Child: React.FC<{ value: T }>; }) => { - return innerF(binder.value); + const components = _.map( + _.filter(binder.boundVars, isNamedBoundVariable), + v => + ); + + const b = + components.length === 0 ? null : ( + <> + for + + + + {nbsp} + + ); + return ( + <> + {b} + + + ); }; export const PrintTy = ({ o }: { o: Ty }) => { @@ -183,7 +218,7 @@ export const PrintTyKind = ({ o }: { o: TyKind }) => { return ; } if ("Tuple" in o) { - const components = _.map(o.Tuple, t => () => ); + const components = _.map(o.Tuple, t => ); return ( @@ -269,7 +304,7 @@ export const PrintPolyExistentialPredicates = ({ o: PolyExistentialPredicates; }) => { const head = o.data === undefined ? null : ; - const components = _.map(o.autoTraits, t => () => ); + const components = _.map(o.autoTraits, t => ); return ( <> {head} @@ -332,13 +367,15 @@ export const PrintPolyFnSig = ({ o }: { o: PolyFnSig }) => { cVariadic: boolean; }) => { const tyCtx = useContext(TyCtxt)!; - const inputComponents = _.map(inputs, ty => () => ); + const inputComponents = _.map(inputs, ty => ); const variadic = !cVariadic ? null : inputs.length === 0 ? "..." : ", ..."; const outVal = tyCtx.interner[output]; - const ret = tyIsUnit(outVal) ? null : ( + const ret = isUnitTy(outVal) ? null : ( <> - {" "} - {"->"} + {nbsp} + {"->"} + {nbsp} + ); return ( @@ -385,20 +422,24 @@ export const PrintPolyFnSig = ({ o }: { o: PolyFnSig }) => { } }; - const inner = (o: FnSig) => { - const unsafetyStr = o.safety === "Unsafe" ? "unsafe " : null; - const abi = ; - const [inputs, output] = fnInputsAndOutput(o.inputs_and_output); + const Inner = ({ value }: { value: FnSig }) => { + const unsafetyStr = value.safety === "Unsafe" ? "unsafe " : null; + const abi = ; + const [inputs, output] = fnInputsAndOutput(value.inputs_and_output); return ( <> {unsafetyStr} {abi}fn - + ); }; - return ; + return ; }; export const PrintFnDef = ({ o }: { o: FnDef }) => { @@ -501,7 +542,8 @@ export const PrintRegion = ({ o }: { o: Region }) => { } case "Anonymous": { // TODO: maybe we don't want to print anonymous lifetimes? - return "'_"; + // return "'_"; + return null; } default: { throw new Error("Unknown region type", o); @@ -531,6 +573,43 @@ export const PrintBoundVariable = ({ o }: { o: BoundVariable }) => { throw new Error("Unknown bound variable", o); }; +export const PrintBoundTyKind = ({ o }: { o: BoundTyKind }) => { + if ("Anon" === o) { + return null; + } else if ("Param" in o) { + const [name] = o.Param; + return ; + } + + throw new Error("Unknown bound ty kind", o); +}; + +export const PrintBoundVariableKind = ({ o }: { o: BoundVariableKind }) => { + if ("Const" === o) { + // TODO: not sure what to do with boudn "consts", we don't have data for them. + return null; + } else if ("Ty" in o) { + return ; + } else if ("Region" in o) { + return ; + } + + throw new Error("Unknown bound variable kind", o); +}; + +export const PrintBoundRegionKind = ({ o }: { o: BoundRegionKind }) => { + // TODO: what do we do in these cases? + if ("BrAnon" === o) { + return null; + } else if ("BrEnv" === o) { + return null; + } + if ("BrNamed" in o && o.BrNamed[0] !== "'_") { + const [name] = o.BrNamed; + return ; + } +}; + export const PrintPolarity = ({ o }: { o: Polarity }) => { return o === "Negative" ? "!" : o === "Maybe" ? "?" : null; }; @@ -539,7 +618,7 @@ export const PrintOpaqueImplType = ({ o }: { o: OpaqueImpl }) => { console.debug("Printing OpaqueImplType", o); const PrintFnTrait = ({ o }: { o: FnTrait }) => { - const args = _.map(o.params, param => () => ); + const args = _.map(o.params, param => ); const ret = o.retTy !== undefined ? ( <> @@ -571,10 +650,8 @@ export const PrintOpaqueImplType = ({ o }: { o: OpaqueImpl }) => { console.debug("Printing Trait", o); const prefix = ; const name = ; - const ownArgs = _.map(o.ownArgs, arg => () => ); - const assocArgs = _.map(o.assocArgs, arg => () => ( - - )); + const ownArgs = _.map(o.ownArgs, arg => ); + const assocArgs = _.map(o.assocArgs, arg => ); const argComponents = [...ownArgs, ...assocArgs]; const list = argComponents.length === 0 ? null : ( @@ -591,9 +668,9 @@ export const PrintOpaqueImplType = ({ o }: { o: OpaqueImpl }) => { ); }; - const fnTraits = _.map(o.fnTraits, trait => () => ); - const traits = _.map(o.traits, trait => () => ); - const lifetimes = _.map(o.lifetimes, lifetime => () => ( + const fnTraits = _.map(o.fnTraits, trait => ); + const traits = _.map(o.traits, trait => ); + const lifetimes = _.map(o.lifetimes, lifetime => ( )); const implComponents = _.concat(fnTraits, traits, lifetimes);