make enum size not depend on the order of variants
adwinwhite committed Oct 14, 2024
1 parent c0838c8 commit 7af5de3
Showing 7 changed files with 323 additions and 140 deletions.
16 changes: 12 additions & 4 deletions compiler/rustc_abi/src/layout.rs
@@ -192,7 +192,7 @@ impl<Cx: HasDataLayout> LayoutCalculator<Cx> {
pub fn layout_of_struct_or_enum<
'a,
FieldIdx: Idx,
VariantIdx: Idx,
VariantIdx: Idx + PartialOrd,
F: Deref<Target = &'a LayoutS<FieldIdx, VariantIdx>> + fmt::Debug + Copy,
>(
&self,
@@ -464,7 +464,7 @@ impl<Cx: HasDataLayout> LayoutCalculator<Cx> {
fn layout_of_enum<
'a,
FieldIdx: Idx,
VariantIdx: Idx,
VariantIdx: Idx + PartialOrd,
F: Deref<Target = &'a LayoutS<FieldIdx, VariantIdx>> + fmt::Debug + Copy,
>(
&self,
@@ -524,8 +524,16 @@ impl<Cx: HasDataLayout> LayoutCalculator<Cx> {
let niche_variants = all_indices.clone().find(|v| needs_disc(*v)).unwrap()
..=all_indices.rev().find(|v| needs_disc(*v)).unwrap();

let count =
(niche_variants.end().index() as u128 - niche_variants.start().index() as u128) + 1;
let count = {
let niche_variants_len = (niche_variants.end().index() as u128
- niche_variants.start().index() as u128)
+ 1;
if niche_variants.contains(&largest_variant_index) {
niche_variants_len - 1
} else {
niche_variants_len
}
};

// Use the largest niche in the largest variant.
let niche = variant_layouts[largest_variant_index].largest_niche?;
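The layout change above reduces to plain index arithmetic: the niche-value count is shortened by one whenever the untagged (largest) variant falls inside the `niche_variants` range. A minimal standalone sketch of that computation, using `usize` indices in place of rustc's `VariantIdx` (the function name is illustrative, not compiler code):

```rust
use std::ops::RangeInclusive;

/// Number of niche values the enum needs. Plain `usize` indices stand in
/// for rustc's `VariantIdx` here.
fn niche_value_count(
    niche_variants: &RangeInclusive<usize>,
    largest_variant_index: usize,
) -> u128 {
    let len =
        (*niche_variants.end() as u128 - *niche_variants.start() as u128) + 1;
    // If the untagged (largest) variant lies inside the range, it no longer
    // consumes a niche value of its own, so one fewer value is required.
    if niche_variants.contains(&largest_variant_index) { len - 1 } else { len }
}

fn main() {
    // Variants 0..=3 need consideration and variant 2 is the untagged one:
    // three niche values suffice instead of four.
    assert_eq!(niche_value_count(&(0..=3), 2), 3);
    // With the untagged variant outside the range, nothing is saved.
    assert_eq!(niche_value_count(&(0..=1), 2), 2);
}
```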
56 changes: 50 additions & 6 deletions compiler/rustc_abi/src/lib.rs
@@ -1475,15 +1475,59 @@ pub enum TagEncoding<VariantIdx: Idx> {
Direct,

/// Niche (values invalid for a type) encoding the discriminant:
/// Discriminant and variant index coincide.
/// Discriminant and variant index don't always coincide.
///
/// The variant `untagged_variant` contains a niche at an arbitrary
/// offset (field `tag_field` of the enum), which for a variant with
/// discriminant `d` is set to
/// `(d - niche_variants.start).wrapping_add(niche_start)`.
/// discriminant `d` is set to `d.wrapping_add(niche_start)`.
///
/// For example, `Option<(usize, &T)>` is represented such that
/// `None` has a null pointer for the second tuple field, and
/// `Some` is the identity function (with a non-null reference).
/// When computing the discriminant there is an optimization: if the `untagged_variant` is
/// contained in the `niche_variants` range, discriminant values are allocated starting from the
/// variant after the `untagged_variant`, so the `untagged_variant` is not assigned an unneeded
/// discriminant. The motivation for this is issue #117238.
/// For example,
/// ```rust
/// enum Example {
/// A, // 1
/// B, // 2
/// C(bool), // untagged_variant, no discriminant
/// D, // has a discriminant of 0
/// }
/// ```
/// The algorithm is as follows:
/// ```text
/// // We ignore leading and trailing variants that don't need discriminants.
/// adjusted_len = niche_variants.end - niche_variants.start + 1
/// adjusted_index = variant_index - niche_variants.start
/// d = if niche_variants.contains(untagged_variant) {
/// adjusted_untagged_index = untagged_variant - niche_variants.start
/// (adjusted_index + adjusted_len - adjusted_untagged_index) % adjusted_len - 1
/// } else {
/// adjusted_index
/// }
/// tag_value = d.wrapping_add(niche_start)
/// ```
/// To recover the variant index from a tag value:
/// ```text
/// adjusted_len = niche_variants.end - niche_variants.start + 1
/// d = tag_value.wrapping_sub(niche_start)
/// variant_index = if niche_variants.contains(untagged_variant) {
/// if d < adjusted_len - 1 {
/// adjusted_untagged_index = untagged_variant - niche_variants.start
/// (d + 1 + adjusted_untagged_index) % adjusted_len + niche_variants.start
/// } else {
/// // When the relative discriminant is at least the number of variants that
/// // have a discriminant, the tag must encode the untagged_variant.
/// untagged_variant
/// }
/// } else {
/// if d < adjusted_len {
/// d + niche_variants.start
/// } else {
/// untagged_variant
/// }
/// }
/// ```
Niche {
untagged_variant: VariantIdx,
niche_variants: RangeInclusive<VariantIdx>,
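The encoding documented above can be exercised with a small standalone sketch. The helper names `encode_tag` and `decode_tag`, and the use of plain `usize`/`u128` values, are assumptions made for illustration; they are not compiler APIs:

```rust
use std::ops::RangeInclusive;

/// Map a variant index to its tag value, following the algorithm in the
/// `TagEncoding::Niche` doc comment. `usize` indices stand in for rustc's
/// `VariantIdx`. Only meaningful for tagged variants: the untagged variant
/// never stores a tag.
fn encode_tag(
    variant_index: usize,
    niche_variants: &RangeInclusive<usize>,
    untagged_variant: usize,
    niche_start: u128,
) -> u128 {
    let start = *niche_variants.start();
    let len = *niche_variants.end() - start + 1;
    let adj_idx = variant_index - start;
    let d = if niche_variants.contains(&untagged_variant) {
        let adj_untagged = untagged_variant - start;
        (adj_idx + len - adj_untagged) % len - 1
    } else {
        adj_idx
    };
    (d as u128).wrapping_add(niche_start)
}

/// Recover the variant index from a tag value (the inverse mapping).
fn decode_tag(
    tag: u128,
    niche_variants: &RangeInclusive<usize>,
    untagged_variant: usize,
    niche_start: u128,
) -> usize {
    let start = *niche_variants.start();
    let len = *niche_variants.end() - start + 1;
    let d = tag.wrapping_sub(niche_start);
    if niche_variants.contains(&untagged_variant) {
        if d < (len - 1) as u128 {
            let adj_untagged = untagged_variant - start;
            (d as usize + 1 + adj_untagged) % len + start
        } else {
            untagged_variant
        }
    } else if d < len as u128 {
        d as usize + start
    } else {
        untagged_variant
    }
}

fn main() {
    // The doc-comment example: A=0, B=1, C(bool)=2 (untagged), D=3, with the
    // niche values assumed to start at 2 (the first invalid `bool` value).
    let (range, untagged, niche_start) = (0..=3usize, 2usize, 2u128);
    // A -> 1, B -> 2, D -> 0, matching the comments on the example enum.
    assert_eq!(encode_tag(0, &range, untagged, niche_start) - niche_start, 1);
    assert_eq!(encode_tag(1, &range, untagged, niche_start) - niche_start, 2);
    assert_eq!(encode_tag(3, &range, untagged, niche_start) - niche_start, 0);
    // Every tagged variant round-trips through encode/decode.
    for v in [0, 1, 3] {
        let tag = encode_tag(v, &range, untagged, niche_start);
        assert_eq!(decode_tag(tag, &range, untagged, niche_start), v);
    }
}
```

In particular, the untagged variant `C` never receives a discriminant, which is the property the commit uses to make the enum's size independent of variant declaration order.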
169 changes: 105 additions & 64 deletions compiler/rustc_codegen_cranelift/src/discriminant.rs
@@ -51,10 +51,20 @@ pub(crate) fn codegen_set_discriminant<'tcx>(
variants: _,
} => {
if variant_index != untagged_variant {
let discr_len = niche_variants.end().index() - niche_variants.start().index() + 1;
let adj_idx = variant_index.index() - niche_variants.start().index();

let niche = place.place_field(fx, FieldIdx::new(tag_field));
let niche_type = fx.clif_type(niche.layout().ty).unwrap();
let niche_value = variant_index.as_u32() - niche_variants.start().as_u32();
let niche_value = (niche_value as u128).wrapping_add(niche_start);

let discr = if niche_variants.contains(&untagged_variant) {
let adj_untagged_idx =
untagged_variant.index() - niche_variants.start().index();
(adj_idx + discr_len - adj_untagged_idx) % discr_len - 1
} else {
adj_idx
};
let niche_value = (discr as u128).wrapping_add(niche_start);
let niche_value = match niche_type {
types::I128 => {
let lsb = fx.bcx.ins().iconst(types::I64, niche_value as u64 as i64);
@@ -130,72 +140,103 @@ pub(crate) fn codegen_get_discriminant<'tcx>(
dest.write_cvalue(fx, res);
}
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();

// We have a subrange `niche_start..=niche_end` inside `range`.
// If the value of the tag is inside this subrange, it's a
// "niche value", an increment of the discriminant. Otherwise it
// indicates the untagged variant.
// A general algorithm to extract the discriminant from the tag
// is:
// relative_tag = tag - niche_start
// is_niche = relative_tag <= (ule) relative_max
// discr = if is_niche {
// cast(relative_tag) + niche_variants.start()
// } else {
// untagged_variant
// }
// However, we will likely be able to emit simpler code.

let (is_niche, tagged_discr, delta) = if relative_max == 0 {
// Best case scenario: only one tagged variant. This will
// likely become just a comparison and a jump.
// The algorithm is:
// is_niche = tag == niche_start
// discr = if is_niche {
// niche_start
// } else {
// untagged_variant
// }
let is_niche = codegen_icmp_imm(fx, IntCC::Equal, tag, niche_start as i128);
// See the algorithm explanation in the definition of `TagEncoding::Niche`.
let discr_len = niche_variants.end().index() - niche_variants.start().index() + 1;

let niche_start = match fx.bcx.func.dfg.value_type(tag) {
types::I128 => {
let lsb = fx.bcx.ins().iconst(types::I64, niche_start as u64 as i64);
let msb = fx.bcx.ins().iconst(types::I64, (niche_start >> 64) as u64 as i64);
fx.bcx.ins().iconcat(lsb, msb)
}
ty => fx.bcx.ins().iconst(ty, niche_start as i64),
};

let (is_niche, tagged_discr) = if discr_len == 1 {
// Special case where we only have a single tagged variant.
// The untagged variant can't be contained in niche_variants' range in this case.
// Thus the discriminant of the only tagged variant is 0 and its variant index
// is the start of niche_variants.
let is_niche = codegen_icmp_imm(fx, IntCC::Equal, tag, niche_start);
let tagged_discr =
fx.bcx.ins().iconst(cast_to, niche_variants.start().as_u32() as i64);
(is_niche, tagged_discr, 0)
(is_niche, tagged_discr)
} else {
// The special cases don't apply, so we'll have to go with
// the general algorithm.
let niche_start = match fx.bcx.func.dfg.value_type(tag) {
types::I128 => {
let lsb = fx.bcx.ins().iconst(types::I64, niche_start as u64 as i64);
let msb =
fx.bcx.ins().iconst(types::I64, (niche_start >> 64) as u64 as i64);
fx.bcx.ins().iconcat(lsb, msb)
}
ty => fx.bcx.ins().iconst(ty, niche_start as i64),
};
let relative_discr = fx.bcx.ins().isub(tag, niche_start);
let cast_tag = clif_intcast(fx, relative_discr, cast_to, false);
let is_niche = crate::common::codegen_icmp_imm(
fx,
IntCC::UnsignedLessThanOrEqual,
relative_discr,
i128::from(relative_max),
);
(is_niche, cast_tag, niche_variants.start().as_u32() as u128)
};
// General case.
let discr = fx.bcx.ins().isub(tag, niche_start);
let tagged_discr = clif_intcast(fx, discr, cast_to, false);
if niche_variants.contains(&untagged_variant) {
let is_niche = crate::common::codegen_icmp_imm(
fx,
IntCC::UnsignedLessThan,
discr,
i128::from(discr_len - 1),
);
let adj_untagged_idx =
untagged_variant.index() - niche_variants.start().index();
let untagged_delta = 1 + adj_untagged_idx;
let untagged_delta = match cast_to {
types::I128 => {
let lsb = fx.bcx.ins().iconst(types::I64, untagged_delta as u64 as i64);
let msb = fx
.bcx
.ins()
.iconst(types::I64, (untagged_delta >> 64) as u64 as i64);
fx.bcx.ins().iconcat(lsb, msb)
}
ty => fx.bcx.ins().iconst(ty, untagged_delta as i64),
};
let tagged_discr = fx.bcx.ins().iadd(tagged_discr, untagged_delta);

let tagged_discr = if delta == 0 {
tagged_discr
} else {
let delta = match cast_to {
types::I128 => {
let lsb = fx.bcx.ins().iconst(types::I64, delta as u64 as i64);
let msb = fx.bcx.ins().iconst(types::I64, (delta >> 64) as u64 as i64);
fx.bcx.ins().iconcat(lsb, msb)
}
ty => fx.bcx.ins().iconst(ty, delta as i64),
};
fx.bcx.ins().iadd(tagged_discr, delta)
let discr_len = match cast_to {
types::I128 => {
let lsb = fx.bcx.ins().iconst(types::I64, discr_len as u64 as i64);
let msb =
fx.bcx.ins().iconst(types::I64, (discr_len >> 64) as u64 as i64);
fx.bcx.ins().iconcat(lsb, msb)
}
ty => fx.bcx.ins().iconst(ty, discr_len as i64),
};
let tagged_discr = fx.bcx.ins().urem(tagged_discr, discr_len);

let niche_variants_start = niche_variants.start().index();
let niche_variants_start = match cast_to {
types::I128 => {
let lsb =
fx.bcx.ins().iconst(types::I64, niche_variants_start as u64 as i64);
let msb = fx
.bcx
.ins()
.iconst(types::I64, (niche_variants_start >> 64) as u64 as i64);
fx.bcx.ins().iconcat(lsb, msb)
}
ty => fx.bcx.ins().iconst(ty, niche_variants_start as i64),
};
let tagged_discr = fx.bcx.ins().iadd(tagged_discr, niche_variants_start);
(is_niche, tagged_discr)
} else {
let is_niche = crate::common::codegen_icmp_imm(
fx,
IntCC::UnsignedLessThan,
discr,
i128::from(discr_len),
);
let niche_variants_start = niche_variants.start().index();
let niche_variants_start = match cast_to {
types::I128 => {
let lsb =
fx.bcx.ins().iconst(types::I64, niche_variants_start as u64 as i64);
let msb = fx
.bcx
.ins()
.iconst(types::I64, (niche_variants_start >> 64) as u64 as i64);
fx.bcx.ins().iconcat(lsb, msb)
}
ty => fx.bcx.ins().iconst(ty, niche_variants_start as i64),
};
let tagged_discr = fx.bcx.ins().iadd(tagged_discr, niche_variants_start);
(is_niche, tagged_discr)
}
};

let untagged_variant = if cast_to == types::I128 {
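For the `discr_len == 1` fast path above, loading the discriminant collapses to a single equality test against `niche_start`. A scalar model of that reduction (illustrative names, not compiler code):

```rust
/// Scalar model of the single-tagged-variant fast path: the tag either equals
/// `niche_start` (the one tagged variant) or it does not (the untagged one).
fn decode_single_tagged(
    tag: u128,
    niche_start: u128,
    niche_variants_start: usize,
    untagged_variant: usize,
) -> usize {
    if tag == niche_start { niche_variants_start } else { untagged_variant }
}

fn main() {
    // Modeled on an `Option<&T>`-style layout: a zero tag (the null niche)
    // selects the single tagged variant (index 0), while any other bit
    // pattern selects the untagged, reference-carrying variant (index 1).
    assert_eq!(decode_single_tagged(0, 0, 0, 1), 0);
    assert_eq!(decode_single_tagged(0x1234, 0, 0, 1), 1);
}
```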
14 changes: 11 additions & 3 deletions compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/mod.rs
@@ -391,9 +391,17 @@ fn compute_discriminant_value<'ll, 'tcx>(

DiscrResult::Range(min, max)
} else {
let value = (variant_index.as_u32() as u128)
.wrapping_sub(niche_variants.start().as_u32() as u128)
.wrapping_add(niche_start);
let discr_len = niche_variants.end().index() - niche_variants.start().index() + 1;
let adj_idx = variant_index.index() - niche_variants.start().index();

let discr = if niche_variants.contains(&untagged_variant) {
let adj_untagged_idx =
untagged_variant.index() - niche_variants.start().index();
(adj_idx + discr_len - adj_untagged_idx) % discr_len - 1
} else {
adj_idx
};
let value = (discr as u128).wrapping_add(niche_start);
let value = tag.size(cx).truncate(value);
DiscrResult::Value(value)
}