From 612faaab45df0a0a1c9bcae5fe4841a6c00b8786 Mon Sep 17 00:00:00 2001 From: Jakub Dupak Date: Thu, 5 Oct 2023 12:36:08 +0200 Subject: [PATCH] TyTy: use new subclass API gcc/rust/ChangeLog: * typecheck/rust-tyty.cc (BaseType::is_unit): Refactor. (BaseType::satisfies_bound): Refactor. (BaseType::get_root): Refactor. (BaseType::destructure): Refactor. (BaseType::monomorphized_clone): Refactor. (BaseType::is_concrete): Refactor. (InferType::InferType): Refactor. (InferType::clone): Refactor. (InferType::apply_primitive_type_hint): Refactor. (StructFieldType::is_equal): Refactor. (ADTType::is_equal): Refactor. (handle_substitions): Refactor. (ADTType::handle_substitions): Refactor. (TupleType::TupleType): Refactor. (TupleType::is_equal): Refactor. (TupleType::handle_substitions): Refactor. Signed-off-by: Jakub Dupak --- gcc/rust/typecheck/rust-tyty.cc | 542 +++++++++++++------------------- 1 file changed, 226 insertions(+), 316 deletions(-) diff --git a/gcc/rust/typecheck/rust-tyty.cc b/gcc/rust/typecheck/rust-tyty.cc index da8710dd2078..64db06dd0785 100644 --- a/gcc/rust/typecheck/rust-tyty.cc +++ b/gcc/rust/typecheck/rust-tyty.cc @@ -220,16 +220,15 @@ BaseType::is_unit () const return true; case TUPLE: { - const TupleType &tuple = *static_cast (x); - return tuple.num_fields () == 0; + return x->as ()->num_fields () == 0; } case ADT: { - const ADTType &adt = *static_cast (x); - if (adt.is_enum ()) + auto adt = x->as (); + if (adt->is_enum ()) return false; - for (const auto &variant : adt.get_variants ()) + for (const auto &variant : adt->get_variants ()) { if (variant->num_fields () > 0) return false; @@ -276,8 +275,6 @@ bool BaseType::satisfies_bound (const TypeBoundPredicate &predicate, bool emit_error) const { - bool is_infer_var = destructure ()->get_kind () == TyTy::TypeKind::INFER; - const Resolver::TraitReference *query = predicate.get (); for (const auto &bound : specified_bounds) { @@ -286,7 +283,7 @@ BaseType::satisfies_bound (const TypeBoundPredicate &predicate, return true; } - if (is_infer_var) + if (destructure ()->is ()) return true; bool satisfied = false; @@ -435,28 +432,24 @@ BaseType::get_root () const { // FIXME this needs to be it its own visitor class with a vector adjustments const TyTy::BaseType *root = this; - if (get_kind () == TyTy::REF) + + if (const auto r = root->try_as ()) { - const ReferenceType *r = static_cast (root); root = r->get_base ()->get_root (); } - else if (get_kind () == TyTy::POINTER) + else if (const auto r = root->try_as ()) { - const PointerType *r = static_cast (root); root = r->get_base ()->get_root (); } - // these are an unsize - else if (get_kind () == TyTy::SLICE) + else if (const auto r = root->try_as ()) { - const SliceType *r = static_cast (root); root = r->get_element_type ()->get_root (); } - // else if (get_kind () == TyTy::ARRAY) - // { - // const ArrayType *r = static_cast (root); - // root = r->get_element_type ()->get_root (); - // } + // else if (const auto r = root->try_as ()) + // { + // root = r->get_element_type ()->get_root (); + // } return root; } @@ -478,34 +471,27 @@ BaseType::destructure () return new ErrorType (get_ref ()); } - switch (x->get_kind ()) + if (auto p = x->try_as ()) { - case TyTy::TypeKind::PARAM: { - TyTy::ParamType *p = static_cast (x); - TyTy::BaseType *pr = p->resolve (); - if (pr == x) - return pr; - - x = pr; - } - break; + auto pr = p->resolve (); + if (pr == x) + return pr; - case TyTy::TypeKind::PLACEHOLDER: { - TyTy::PlaceholderType *p = static_cast (x); - if (!p->can_resolve ()) - return p; - - x = p->resolve (); - } - break; - - case TyTy::TypeKind::PROJECTION: { - TyTy::ProjectionType *p = static_cast (x); - x = p->get (); - } - break; + x = pr; + } + else if (auto p = x->try_as ()) + { + if (!p->can_resolve ()) + return p; - default: + x = p->resolve (); + } + else if (auto p = x->try_as ()) + { + x = p->get (); + } + else + { return x; } } @@ -530,36 +516,27 @@ BaseType::destructure () const return new ErrorType (get_ref ()); } - switch (x->get_kind ()) + if (auto p = x->try_as ()) { - case TyTy::TypeKind::PARAM: { - const TyTy::ParamType *p = static_cast (x); - const TyTy::BaseType *pr = p->resolve (); - if (pr == x) - return pr; - - x = pr; - } - break; + auto pr = p->resolve (); + if (pr == x) + return pr; - case TyTy::TypeKind::PLACEHOLDER: { - const TyTy::PlaceholderType *p - = static_cast (x); - if (!p->can_resolve ()) - return p; - - x = p->resolve (); - } - break; - - case TyTy::TypeKind::PROJECTION: { - const TyTy::ProjectionType *p - = static_cast (x); - x = p->get (); - } - break; + x = pr; + } + else if (auto p = x->try_as ()) + { + if (!p->can_resolve ()) + return p; - default: + x = p->resolve (); + } + else if (auto p = x->try_as ()) + { + x = p->get (); + } + else + { return x; } } @@ -571,112 +548,81 @@ BaseType * BaseType::monomorphized_clone () const { const TyTy::BaseType *x = destructure (); - switch (x->get_kind ()) - { - case PARAM: - case PROJECTION: - case PLACEHOLDER: - case INFER: - case BOOL: - case CHAR: - case INT: - case UINT: - case FLOAT: - case USIZE: - case ISIZE: - case NEVER: - case STR: - case DYNAMIC: - case CLOSURE: - case ERROR: - return x->clone (); - - case ARRAY: { - const ArrayType &arr = *static_cast (x); - TyVar elm = arr.get_var_element_type ().monomorphized_clone (); - return new ArrayType (arr.get_ref (), arr.get_ty_ref (), ident.locus, - arr.get_capacity_expr (), elm, - arr.get_combined_refs ()); - } - break; - - case SLICE: { - const SliceType &slice = *static_cast (x); - TyVar elm = slice.get_var_element_type ().monomorphized_clone (); - return new SliceType (slice.get_ref (), slice.get_ty_ref (), - ident.locus, elm, slice.get_combined_refs ()); - } - break; - - case POINTER: { - const PointerType &ptr = *static_cast (x); - TyVar elm = ptr.get_var_element_type ().monomorphized_clone (); - return new PointerType (ptr.get_ref (), ptr.get_ty_ref (), elm, - ptr.mutability (), ptr.get_combined_refs ()); - } - break; - - case REF: { - const ReferenceType &ref = *static_cast (x); - TyVar elm = ref.get_var_element_type ().monomorphized_clone (); - return new ReferenceType (ref.get_ref (), ref.get_ty_ref (), elm, - ref.mutability (), ref.get_combined_refs ()); - } - break; - case TUPLE: { - const TupleType &tuple = *static_cast (x); - std::vector cloned_fields; - for (const auto &f : tuple.get_fields ()) - cloned_fields.push_back (f.monomorphized_clone ()); - - return new TupleType (tuple.get_ref (), tuple.get_ty_ref (), - tuple.get_ident ().locus, cloned_fields, - tuple.get_combined_refs ()); - } - break; - - case FNDEF: { - const FnType &fn = *static_cast (x); - std::vector> cloned_params; - for (auto &p : fn.get_params ()) - cloned_params.push_back ({p.first, p.second->monomorphized_clone ()}); - - BaseType *retty = fn.get_return_type ()->monomorphized_clone (); - return new FnType (fn.get_ref (), fn.get_ty_ref (), fn.get_id (), - fn.get_identifier (), fn.ident, fn.get_flags (), - fn.get_abi (), std::move (cloned_params), retty, - fn.clone_substs (), fn.get_combined_refs ()); - } - break; - - case FNPTR: { - const FnPtr &fn = *static_cast (x); - std::vector cloned_params; - for (auto &p : fn.get_params ()) - cloned_params.push_back (p.monomorphized_clone ()); - - TyVar retty = fn.get_var_return_type ().monomorphized_clone (); - return new FnPtr (fn.get_ref (), fn.get_ty_ref (), fn.ident.locus, - std::move (cloned_params), retty, - fn.get_combined_refs ()); - } - break; + if (auto arr = x->try_as ()) + { + TyVar elm = arr->get_var_element_type ().monomorphized_clone (); + return new ArrayType (arr->get_ref (), arr->get_ty_ref (), ident.locus, + arr->get_capacity_expr (), elm, + arr->get_combined_refs ()); + } + else if (auto slice = x->try_as ()) + { + TyVar elm = slice->get_var_element_type ().monomorphized_clone (); + return new SliceType (slice->get_ref (), slice->get_ty_ref (), + ident.locus, elm, slice->get_combined_refs ()); + } + else if (auto ptr = x->try_as ()) + { + TyVar elm = ptr->get_var_element_type ().monomorphized_clone (); + return new PointerType (ptr->get_ref (), ptr->get_ty_ref (), elm, + ptr->mutability (), ptr->get_combined_refs ()); + } + else if (auto ref = x->try_as ()) + { + TyVar elm = ref->get_var_element_type ().monomorphized_clone (); + return new ReferenceType (ref->get_ref (), ref->get_ty_ref (), elm, + ref->mutability (), ref->get_combined_refs ()); + } + else if (auto tuple = x->try_as ()) + { + std::vector cloned_fields; + for (const auto &f : tuple->get_fields ()) + cloned_fields.push_back (f.monomorphized_clone ()); - case ADT: { - const ADTType &adt = *static_cast (x); - std::vector cloned_variants; - for (auto &variant : adt.get_variants ()) - cloned_variants.push_back (variant->monomorphized_clone ()); - - return new ADTType (adt.get_ref (), adt.get_ty_ref (), - adt.get_identifier (), adt.ident, - adt.get_adt_kind (), cloned_variants, - adt.clone_substs (), adt.get_repr_options (), - adt.get_used_arguments (), - adt.get_combined_refs ()); - } - break; + return new TupleType (tuple->get_ref (), tuple->get_ty_ref (), + ident.locus, cloned_fields, + tuple->get_combined_refs ()); + } + else if (auto fn = x->try_as ()) + { + std::vector> cloned_params; + for (auto &p : fn->get_params ()) + cloned_params.push_back ({p.first, p.second->monomorphized_clone ()}); + + BaseType *retty = fn->get_return_type ()->monomorphized_clone (); + return new FnType (fn->get_ref (), fn->get_ty_ref (), fn->get_id (), + fn->get_identifier (), fn->ident, fn->get_flags (), + fn->get_abi (), std::move (cloned_params), retty, + fn->clone_substs (), fn->get_combined_refs ()); + } + else if (auto fn = x->try_as ()) + { + std::vector cloned_params; + for (auto &p : fn->get_params ()) + cloned_params.push_back (p.monomorphized_clone ()); + + TyVar retty = fn->get_var_return_type ().monomorphized_clone (); + return new FnPtr (fn->get_ref (), fn->get_ty_ref (), ident.locus, + std::move (cloned_params), retty, + fn->get_combined_refs ()); + } + else if (auto adt = x->try_as ()) + { + std::vector cloned_variants; + for (auto &variant : adt->get_variants ()) + cloned_variants.push_back (variant->monomorphized_clone ()); + + return new ADTType (adt->get_ref (), adt->get_ty_ref (), + adt->get_identifier (), adt->ident, + adt->get_adt_kind (), cloned_variants, + adt->clone_substs (), adt->get_repr_options (), + adt->get_used_arguments (), + adt->get_combined_refs ()); + } + else + { + return x->clone (); } rust_unreachable (); @@ -714,122 +660,94 @@ bool BaseType::is_concrete () const { const TyTy::BaseType *x = destructure (); - switch (x->get_kind ()) + + if (x->is () || x->is ()) { - case PARAM: - case PROJECTION: return false; - - // placeholder is a special case for this case when it is not resolvable - // it means we its just an empty placeholder associated type which is - // concrete - case PLACEHOLDER: + } + // placeholder is a special case for this case when it is not resolvable + // it means we its just an empty placeholder associated type which is + // concrete + else if (x->is ()) + { return true; + } + else if (auto fn = x->try_as ()) + { + for (const auto ¶m : fn->get_params ()) + { + if (!param.second->is_concrete ()) + return false; + } + return fn->get_return_type ()->is_concrete (); + } + else if (auto fn = x->try_as ()) + { + for (const auto ¶m : fn->get_params ()) + { + if (!param.get_tyty ()->is_concrete ()) + return false; + } + return fn->get_return_type ()->is_concrete (); + } + else if (auto adt = x->try_as ()) + { + if (adt->is_unit ()) + return !adt->needs_substitution (); - case FNDEF: { - const FnType &fn = *static_cast (x); - for (const auto ¶m : fn.get_params ()) - { - const BaseType *p = param.second; - if (!p->is_concrete ()) - return false; - } - return fn.get_return_type ()->is_concrete (); - } - break; - - case FNPTR: { - const FnPtr &fn = *static_cast (x); - for (const auto ¶m : fn.get_params ()) - { - const BaseType *p = param.get_tyty (); - if (!p->is_concrete ()) - return false; - } - return fn.get_return_type ()->is_concrete (); - } - break; - - case ADT: { - const ADTType &adt = *static_cast (x); - if (adt.is_unit ()) - { - return !adt.needs_substitution (); - } - - for (auto &variant : adt.get_variants ()) - { - bool is_num_variant - = variant->get_variant_type () == VariantDef::VariantType::NUM; - if (is_num_variant) - continue; - - for (auto &field : variant->get_fields ()) - { - const BaseType *field_type = field->get_field_type (); - if (!field_type->is_concrete ()) - return false; - } - } - return true; - } - break; - - case ARRAY: { - const ArrayType &arr = *static_cast (x); - return arr.get_element_type ()->is_concrete (); - } - break; - - case SLICE: { - const SliceType &slice = *static_cast (x); - return slice.get_element_type ()->is_concrete (); - } - break; - - case POINTER: { - const PointerType &ptr = *static_cast (x); - return ptr.get_base ()->is_concrete (); - } - break; - - case REF: { - const ReferenceType &ref = *static_cast (x); - return ref.get_base ()->is_concrete (); - } - break; - - case TUPLE: { - const TupleType &tuple = *static_cast (x); - for (size_t i = 0; i < tuple.num_fields (); i++) - { - if (!tuple.get_field (i)->is_concrete ()) - return false; - } - return true; - } - break; - - case CLOSURE: { - const ClosureType &closure = *static_cast (x); - if (closure.get_parameters ().is_concrete ()) - return false; - return closure.get_result_type ().is_concrete (); - } - break; + for (auto &variant : adt->get_variants ()) + { + bool is_num_variant + = variant->get_variant_type () == VariantDef::VariantType::NUM; + if (is_num_variant) + continue; - case INFER: - case BOOL: - case CHAR: - case INT: - case UINT: - case FLOAT: - case USIZE: - case ISIZE: - case NEVER: - case STR: - case DYNAMIC: - case ERROR: + for (auto &field : variant->get_fields ()) + { + const BaseType *field_type = field->get_field_type (); + if (!field_type->is_concrete ()) + return false; + } + } + return true; + } + else if (auto arr = x->try_as ()) + { + return arr->get_element_type ()->is_concrete (); + } + else if (auto slice = x->try_as ()) + { + return slice->get_element_type ()->is_concrete (); + } + else if (auto ptr = x->try_as ()) + { + return ptr->get_base ()->is_concrete (); + } + else if (auto ref = x->try_as ()) + { + return ref->get_base ()->is_concrete (); + } + else if (auto tuple = x->try_as ()) + { + for (size_t i = 0; i < tuple->num_fields (); i++) + { + if (!tuple->get_field (i)->is_concrete ()) + return false; + } + return true; + } + else if (auto closure = x->try_as ()) + { + if (closure->get_parameters ().is_concrete ()) + return false; + return closure->get_result_type ().is_concrete (); + } + else if (x->is () || x->is () || x->is () + || x->is () || x->is () || x->is () + || x->is () || x->is () || x->is () + || x->is () || x->is () + || x->is ()) + { return true; } @@ -1197,10 +1115,9 @@ InferType::apply_primitive_type_hint (const BaseType &hint) case INT: { infer_kind = INTEGRAL; - const IntType &i = static_cast (hint); default_hint.kind = hint.get_kind (); default_hint.shint = TypeHint::SignedHint::SIGNED; - switch (i.get_int_kind ()) + switch (hint.as ()->get_int_kind ()) { case IntType::I8: default_hint.szhint = TypeHint::SizeHint::S8; @@ -1223,10 +1140,9 @@ InferType::apply_primitive_type_hint (const BaseType &hint) case UINT: { infer_kind = INTEGRAL; - const UintType &i = static_cast (hint); default_hint.kind = hint.get_kind (); default_hint.shint = TypeHint::SignedHint::UNSIGNED; - switch (i.get_uint_kind ()) + switch (hint.as ()->get_uint_kind ()) { case UintType::U8: default_hint.szhint = TypeHint::SizeHint::S8; @@ -1251,8 +1167,7 @@ InferType::apply_primitive_type_hint (const BaseType &hint) infer_kind = FLOAT; default_hint.shint = TypeHint::SignedHint::SIGNED; default_hint.kind = hint.get_kind (); - const FloatType &i = static_cast (hint); - switch (i.get_float_kind ()) + switch (hint.as ()->get_float_kind ()) { case FloatType::F32: default_hint.szhint = TypeHint::SizeHint::S32; @@ -1371,14 +1286,11 @@ StructFieldType::as_string () const bool StructFieldType::is_equal (const StructFieldType &other) const { - bool names_eq = get_name ().compare (other.get_name ()) == 0; + bool names_eq = get_name () == other.get_name (); TyTy::BaseType *o = other.get_field_type (); - if (o->get_kind () == TypeKind::PARAM) - { - ParamType *op = static_cast (o); - o = op->resolve (); - } + if (auto op = o->try_as ()) + o = op->resolve (); bool types_eq = get_field_type ()->is_equal (*o); @@ -1673,25 +1585,25 @@ ADTType::is_equal (const BaseType &other) const if (get_kind () != other.get_kind ()) return false; - auto other2 = static_cast (other); - if (get_adt_kind () != other2.get_adt_kind ()) + auto other2 = other.as (); + if (get_adt_kind () != other2->get_adt_kind ()) return false; - if (number_of_variants () != other2.number_of_variants ()) + if (number_of_variants () != other2->number_of_variants ()) return false; - if (has_substitutions_defined () != other2.has_substitutions_defined ()) + if (has_substitutions_defined () != other2->has_substitutions_defined ()) return false; if (has_substitutions_defined ()) { - if (get_num_substitutions () != other2.get_num_substitutions ()) + if (get_num_substitutions () != other2->get_num_substitutions ()) return false; for (size_t i = 0; i < get_num_substitutions (); i++) { const SubstitutionParamMapping &a = substitutions.at (i); - const SubstitutionParamMapping &b = other2.substitutions.at (i); + const SubstitutionParamMapping &b = other2->substitutions.at (i); const ParamType *aa = a.get_param_ty (); const ParamType *bb = b.get_param_ty (); @@ -1705,7 +1617,7 @@ ADTType::is_equal (const BaseType &other) const for (size_t i = 0; i < number_of_variants (); i++) { const TyTy::VariantDef *a = get_variants ().at (i); - const TyTy::VariantDef *b = other2.get_variants ().at (i); + const TyTy::VariantDef *b = other2->get_variants ().at (i); if (!a->is_equal (*b)) return false; @@ -1732,11 +1644,8 @@ handle_substitions (SubstitutionArgumentMappings &subst_mappings, StructFieldType *field) { auto fty = field->get_field_type (); - bool is_param_ty = fty->get_kind () == TypeKind::PARAM; - if (is_param_ty) + if (auto p = fty->try_as ()) { - ParamType *p = static_cast (fty); - SubstitutionArg arg = SubstitutionArg::error (); bool ok = subst_mappings.get_argument_for_symbol (p, &arg); if (ok) @@ -1781,7 +1690,7 @@ handle_substitions (SubstitutionArgumentMappings &subst_mappings, ADTType * ADTType::handle_substitions (SubstitutionArgumentMappings &subst_mappings) { - ADTType *adt = static_cast (clone ()); + auto adt = clone ()->as (); adt->set_ty_ref (mappings->get_next_hir_id ()); adt->used_arguments = subst_mappings; @@ -1905,13 +1814,13 @@ TupleType::is_equal (const BaseType &other) const if (get_kind () != other.get_kind ()) return false; - auto other2 = static_cast (other); - if (num_fields () != other2.num_fields ()) + auto other2 = other.as (); + if (num_fields () != other2->num_fields ()) return false; for (size_t i = 0; i < num_fields (); i++) { - if (!get_field (i)->is_equal (*other2.get_field (i))) + if (!get_field (i)->is_equal (*other2->get_field (i))) return false; } return true; @@ -1933,7 +1842,7 @@ TupleType::handle_substitions (SubstitutionArgumentMappings &mappings) { auto mappings_table = Analysis::Mappings::get (); - TupleType *tuple = static_cast (clone ()); + auto tuple = clone ()->as (); tuple->set_ref (mappings_table->get_next_hir_id ()); tuple->set_ty_ref (mappings_table->get_next_hir_id ()); @@ -3730,7 +3639,8 @@ ProjectionType::handle_substitions ( SubstitutionArgumentMappings &subst_mappings) { // // do we really need to substitute this? - // if (base->needs_generic_substitutions () || base->contains_type_parameters + // if (base->needs_generic_substitutions () || + // base->contains_type_parameters // ()) // { // return this;