Skip to content

Commit

Permalink
Rollup merge of rust-lang#125968 - BoxyUwU:shrink_ty_expr, r=oli-obk
Browse files Browse the repository at this point in the history
Store the types of `ty::Expr` arguments in the `ty::Expr`

Part of rust-lang#125958

In attempting to remove the `ty` field on `Const` it will become necessary to store the `Ty<'tcx>` inside of `Expr<'tcx>`. In order to do this without blowing up the size of `ConstKind`, we start storing the type/const args as `GenericArgs`

r? `@oli-obk`
  • Loading branch information
compiler-errors authored Jun 4, 2024
2 parents 288727e + 67a73f2 commit a5dc684
Show file tree
Hide file tree
Showing 13 changed files with 259 additions and 165 deletions.
4 changes: 2 additions & 2 deletions compiler/rustc_middle/src/ty/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub type ConstKind<'tcx> = ir::ConstKind<TyCtxt<'tcx>>;
pub type UnevaluatedConst<'tcx> = ir::UnevaluatedConst<TyCtxt<'tcx>>;

#[cfg(target_pointer_width = "64")]
rustc_data_structures::static_assert_size!(ConstKind<'_>, 32);
rustc_data_structures::static_assert_size!(ConstKind<'_>, 24);

/// Use this rather than `ConstData`, whenever possible.
#[derive(Copy, Clone, PartialEq, Eq, Hash, HashStable)]
Expand Down Expand Up @@ -58,7 +58,7 @@ pub struct ConstData<'tcx> {
}

#[cfg(target_pointer_width = "64")]
rustc_data_structures::static_assert_size!(ConstData<'_>, 40);
rustc_data_structures::static_assert_size!(ConstData<'_>, 32);

impl<'tcx> Const<'tcx> {
#[inline]
Expand Down
125 changes: 118 additions & 7 deletions compiler/rustc_middle/src/ty/consts/kind.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::Const;
use crate::mir;
use crate::ty::abstract_const::CastKind;
use crate::ty::{self, visit::TypeVisitableExt as _, List, Ty, TyCtxt};
use crate::ty::{self, visit::TypeVisitableExt as _, Ty, TyCtxt};
use rustc_macros::{extension, HashStable, TyDecodable, TyEncodable, TypeFoldable, TypeVisitable};

#[extension(pub(crate) trait UnevaluatedConstEvalExt<'tcx>)]
Expand Down Expand Up @@ -40,14 +40,125 @@ impl<'tcx> ty::UnevaluatedConst<'tcx> {
}
}

#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
#[derive(HashStable, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub enum ExprKind {
Binop(mir::BinOp),
UnOp(mir::UnOp),
FunctionCall,
Cast(CastKind),
}
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
#[derive(HashStable, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub enum Expr<'tcx> {
Binop(mir::BinOp, Const<'tcx>, Const<'tcx>),
UnOp(mir::UnOp, Const<'tcx>),
FunctionCall(Const<'tcx>, &'tcx List<Const<'tcx>>),
Cast(CastKind, Const<'tcx>, Ty<'tcx>),
pub struct Expr<'tcx> {
pub kind: ExprKind,
args: ty::GenericArgsRef<'tcx>,
}
impl<'tcx> Expr<'tcx> {
pub fn new_binop(
tcx: TyCtxt<'tcx>,
binop: mir::BinOp,
lhs_ty: Ty<'tcx>,
rhs_ty: Ty<'tcx>,
lhs_ct: Const<'tcx>,
rhs_ct: Const<'tcx>,
) -> Self {
let args = tcx.mk_args_from_iter::<_, ty::GenericArg<'tcx>>(
[lhs_ty.into(), rhs_ty.into(), lhs_ct.into(), rhs_ct.into()].into_iter(),
);

Self { kind: ExprKind::Binop(binop), args }
}

pub fn binop_args(self) -> (Ty<'tcx>, Ty<'tcx>, Const<'tcx>, Const<'tcx>) {
assert!(matches!(self.kind, ExprKind::Binop(_)));

match self.args().as_slice() {
[lhs_ty, rhs_ty, lhs_ct, rhs_ct] => (
lhs_ty.expect_ty(),
rhs_ty.expect_ty(),
lhs_ct.expect_const(),
rhs_ct.expect_const(),
),
_ => bug!("Invalid args for `Binop` expr {self:?}"),
}
}

pub fn new_unop(tcx: TyCtxt<'tcx>, unop: mir::UnOp, ty: Ty<'tcx>, ct: Const<'tcx>) -> Self {
let args =
tcx.mk_args_from_iter::<_, ty::GenericArg<'tcx>>([ty.into(), ct.into()].into_iter());

Self { kind: ExprKind::UnOp(unop), args }
}

pub fn unop_args(self) -> (Ty<'tcx>, Const<'tcx>) {
assert!(matches!(self.kind, ExprKind::UnOp(_)));

match self.args().as_slice() {
[ty, ct] => (ty.expect_ty(), ct.expect_const()),
_ => bug!("Invalid args for `UnOp` expr {self:?}"),
}
}

pub fn new_call(
tcx: TyCtxt<'tcx>,
func_ty: Ty<'tcx>,
func_expr: Const<'tcx>,
arguments: impl Iterator<Item = Const<'tcx>>,
) -> Self {
let args = tcx.mk_args_from_iter::<_, ty::GenericArg<'tcx>>(
[func_ty.into(), func_expr.into()].into_iter().chain(arguments.map(|ct| ct.into())),
);

Self { kind: ExprKind::FunctionCall, args }
}

pub fn call_args(self) -> (Ty<'tcx>, Const<'tcx>, impl Iterator<Item = Const<'tcx>>) {
assert!(matches!(self.kind, ExprKind::FunctionCall));

match self.args().as_slice() {
[func_ty, func, rest @ ..] => (
func_ty.expect_ty(),
func.expect_const(),
rest.iter().map(|arg| arg.expect_const()),
),
_ => bug!("Invalid args for `Call` expr {self:?}"),
}
}

pub fn new_cast(
tcx: TyCtxt<'tcx>,
cast: CastKind,
value_ty: Ty<'tcx>,
value: Const<'tcx>,
to_ty: Ty<'tcx>,
) -> Self {
let args = tcx.mk_args_from_iter::<_, ty::GenericArg<'tcx>>(
[value_ty.into(), value.into(), to_ty.into()].into_iter(),
);

Self { kind: ExprKind::Cast(cast), args }
}

pub fn cast_args(self) -> (Ty<'tcx>, Const<'tcx>, Ty<'tcx>) {
assert!(matches!(self.kind, ExprKind::Cast(_)));

match self.args().as_slice() {
[value_ty, value, to_ty] => {
(value_ty.expect_ty(), value.expect_const(), to_ty.expect_ty())
}
_ => bug!("Invalid args for `Cast` expr {self:?}"),
}
}

pub fn new(kind: ExprKind, args: ty::GenericArgsRef<'tcx>) -> Self {
Self { kind, args }
}

pub fn args(&self) -> ty::GenericArgsRef<'tcx> {
self.args
}
}

#[cfg(target_pointer_width = "64")]
rustc_data_structures::static_assert_size!(Expr<'_>, 24);
rustc_data_structures::static_assert_size!(Expr<'_>, 16);
21 changes: 1 addition & 20 deletions compiler/rustc_middle/src/ty/flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,26 +374,7 @@ impl FlagComputation {
self.add_flags(TypeFlags::STILL_FURTHER_SPECIALIZABLE);
}
ty::ConstKind::Value(_) => {}
ty::ConstKind::Expr(e) => {
use ty::Expr;
match e {
Expr::Binop(_, l, r) => {
self.add_const(l);
self.add_const(r);
}
Expr::UnOp(_, v) => self.add_const(v),
Expr::FunctionCall(f, args) => {
self.add_const(f);
for arg in args {
self.add_const(arg);
}
}
Expr::Cast(_, c, t) => {
self.add_ty(t);
self.add_const(c);
}
}
}
ty::ConstKind::Expr(e) => self.add_args(e.args()),
ty::ConstKind::Error(_) => self.add_flags(TypeFlags::HAS_ERROR),
}
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ pub use self::closure::{
CAPTURE_STRUCT_LOCAL,
};
pub use self::consts::{
Const, ConstData, ConstInt, ConstKind, Expr, ScalarInt, UnevaluatedConst, ValTree,
Const, ConstData, ConstInt, ConstKind, Expr, ExprKind, ScalarInt, UnevaluatedConst, ValTree,
};
pub use self::context::{
tls, CtxtInterners, CurrentGcx, DeducedParamAttrs, Feed, FreeRegionInfo, GlobalCtxt, Lift,
Expand Down
102 changes: 44 additions & 58 deletions compiler/rustc_middle/src/ty/print/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1533,8 +1533,10 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
print_ty: bool,
) -> Result<(), PrintError> {
define_scoped_cx!(self);
match expr {
Expr::Binop(op, c1, c2) => {
match expr.kind {
ty::ExprKind::Binop(op) => {
let (_, _, c1, c2) = expr.binop_args();

let precedence = |binop: rustc_middle::mir::BinOp| {
use rustc_ast::util::parser::AssocOp;
AssocOp::from_ast_binop(binop.to_hir_binop().into()).precedence()
Expand All @@ -1543,22 +1545,26 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
let formatted_op = op.to_hir_binop().as_str();
let (lhs_parenthesized, rhs_parenthesized) = match (c1.kind(), c2.kind()) {
(
ty::ConstKind::Expr(Expr::Binop(lhs_op, _, _)),
ty::ConstKind::Expr(Expr::Binop(rhs_op, _, _)),
ty::ConstKind::Expr(ty::Expr { kind: ty::ExprKind::Binop(lhs_op), .. }),
ty::ConstKind::Expr(ty::Expr { kind: ty::ExprKind::Binop(rhs_op), .. }),
) => (precedence(lhs_op) < op_precedence, precedence(rhs_op) < op_precedence),
(ty::ConstKind::Expr(Expr::Binop(lhs_op, ..)), ty::ConstKind::Expr(_)) => {
(precedence(lhs_op) < op_precedence, true)
}
(ty::ConstKind::Expr(_), ty::ConstKind::Expr(Expr::Binop(rhs_op, ..))) => {
(true, precedence(rhs_op) < op_precedence)
}
(
ty::ConstKind::Expr(ty::Expr { kind: ty::ExprKind::Binop(lhs_op), .. }),
ty::ConstKind::Expr(_),
) => (precedence(lhs_op) < op_precedence, true),
(
ty::ConstKind::Expr(_),
ty::ConstKind::Expr(ty::Expr { kind: ty::ExprKind::Binop(rhs_op), .. }),
) => (true, precedence(rhs_op) < op_precedence),
(ty::ConstKind::Expr(_), ty::ConstKind::Expr(_)) => (true, true),
(ty::ConstKind::Expr(Expr::Binop(lhs_op, ..)), _) => {
(precedence(lhs_op) < op_precedence, false)
}
(_, ty::ConstKind::Expr(Expr::Binop(rhs_op, ..))) => {
(false, precedence(rhs_op) < op_precedence)
}
(
ty::ConstKind::Expr(ty::Expr { kind: ty::ExprKind::Binop(lhs_op), .. }),
_,
) => (precedence(lhs_op) < op_precedence, false),
(
_,
ty::ConstKind::Expr(ty::Expr { kind: ty::ExprKind::Binop(rhs_op), .. }),
) => (false, precedence(rhs_op) < op_precedence),
(ty::ConstKind::Expr(_), _) => (true, false),
(_, ty::ConstKind::Expr(_)) => (false, true),
_ => (false, false),
Expand All @@ -1574,7 +1580,9 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
rhs_parenthesized,
)?;
}
Expr::UnOp(op, ct) => {
ty::ExprKind::UnOp(op) => {
let (_, ct) = expr.unop_args();

use rustc_middle::mir::UnOp;
let formatted_op = match op {
UnOp::Not => "!",
Expand All @@ -1583,7 +1591,9 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
};
let parenthesized = match ct.kind() {
_ if op == UnOp::PtrMetadata => true,
ty::ConstKind::Expr(Expr::UnOp(c_op, ..)) => c_op != op,
ty::ConstKind::Expr(ty::Expr { kind: ty::ExprKind::UnOp(c_op), .. }) => {
c_op != op
}
ty::ConstKind::Expr(_) => true,
_ => false,
};
Expand All @@ -1593,61 +1603,37 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
parenthesized,
)?
}
Expr::FunctionCall(fn_def, fn_args) => {
use ty::TyKind;
match fn_def.ty().kind() {
TyKind::FnDef(def_id, gen_args) => {
p!(print_value_path(*def_id, gen_args), "(");
if print_ty {
let tcx = self.tcx();
let sig = tcx.fn_sig(def_id).instantiate(tcx, gen_args).skip_binder();

let mut args_with_ty = fn_args.iter().map(|ct| (ct, ct.ty()));
let output_ty = sig.output();

if let Some((ct, ty)) = args_with_ty.next() {
self.typed_value(
|this| this.pretty_print_const(ct, print_ty),
|this| this.pretty_print_type(ty),
": ",
)?;
for (ct, ty) in args_with_ty {
p!(", ");
self.typed_value(
|this| this.pretty_print_const(ct, print_ty),
|this| this.pretty_print_type(ty),
": ",
)?;
}
}
p!(write(") -> {output_ty}"));
} else {
p!(comma_sep(fn_args.iter()), ")");
}
}
_ => bug!("unexpected type of fn def"),
}
ty::ExprKind::FunctionCall => {
let (_, fn_def, fn_args) = expr.call_args();

write!(self, "(")?;
self.pretty_print_const(fn_def, print_ty)?;
p!(")(", comma_sep(fn_args), ")");
}
Expr::Cast(kind, ct, ty) => {
ty::ExprKind::Cast(kind) => {
let (_, value, to_ty) = expr.cast_args();

use ty::abstract_const::CastKind;
if kind == CastKind::As || (kind == CastKind::Use && self.should_print_verbose()) {
let parenthesized = match ct.kind() {
ty::ConstKind::Expr(Expr::Cast(_, _, _)) => false,
let parenthesized = match value.kind() {
ty::ConstKind::Expr(ty::Expr {
kind: ty::ExprKind::Cast { .. }, ..
}) => false,
ty::ConstKind::Expr(_) => true,
_ => false,
};
self.maybe_parenthesized(
|this| {
this.typed_value(
|this| this.pretty_print_const(ct, print_ty),
|this| this.pretty_print_type(ty),
|this| this.pretty_print_const(value, print_ty),
|this| this.pretty_print_type(to_ty),
" as ",
)
},
parenthesized,
)?;
} else {
self.pretty_print_const(ct, print_ty)?
self.pretty_print_const(value, print_ty)?
}
}
}
Expand Down
Loading

0 comments on commit a5dc684

Please sign in to comment.