From 9941d6f3dd4e469bc8962abbba9ffcaa1a5ac60e Mon Sep 17 00:00:00 2001 From: Raiki Tamura Date: Thu, 8 Aug 2024 02:25:05 +0900 Subject: [PATCH] gccrs: Implement initial pattern analysis pass. gcc/rust/ChangeLog: * Make-lang.in: Add rust-hir-pattern-analysis.o. * rust-session-manager.cc (Session::compile_crate): Add pattern analysis pass. * typecheck/rust-hir-type-check-pattern.cc (TypeCheckPattern::visit): Do typecheck for subpatterns. * checks/errors/rust-hir-pattern-analysis.cc: New file. * checks/errors/rust-hir-pattern-analysis.h: New file. gcc/testsuite/ChangeLog: * rust/compile/exhaustiveness1.rs: New test. * rust/compile/exhaustiveness2.rs: New test. * rust/compile/exhaustiveness3.rs: New test. Signed-off-by: Raiki Tamura --- gcc/rust/Make-lang.in | 1 + .../errors/rust-hir-pattern-analysis.cc | 1554 +++++++++++++++++ .../checks/errors/rust-hir-pattern-analysis.h | 526 ++++++ gcc/rust/rust-session-manager.cc | 6 + .../typecheck/rust-hir-type-check-pattern.cc | 2 + gcc/testsuite/rust/compile/exhaustiveness1.rs | 53 + gcc/testsuite/rust/compile/exhaustiveness2.rs | 28 + gcc/testsuite/rust/compile/exhaustiveness3.rs | 55 + 8 files changed, 2225 insertions(+) create mode 100644 gcc/rust/checks/errors/rust-hir-pattern-analysis.cc create mode 100644 gcc/rust/checks/errors/rust-hir-pattern-analysis.h create mode 100644 gcc/testsuite/rust/compile/exhaustiveness1.rs create mode 100644 gcc/testsuite/rust/compile/exhaustiveness2.rs create mode 100644 gcc/testsuite/rust/compile/exhaustiveness3.rs diff --git a/gcc/rust/Make-lang.in b/gcc/rust/Make-lang.in index 79635b4baf5c..73ec2193f50d 100644 --- a/gcc/rust/Make-lang.in +++ b/gcc/rust/Make-lang.in @@ -188,6 +188,7 @@ GRS_OBJS = \ rust/rust-readonly-check.o \ rust/rust-hir-type-check-path.o \ rust/rust-unsafe-checker.o \ + rust/rust-hir-pattern-analysis.o \ rust/rust-compile-intrinsic.o \ rust/rust-compile-pattern.o \ rust/rust-compile-fnparam.o \ diff --git a/gcc/rust/checks/errors/rust-hir-pattern-analysis.cc b/gcc/rust/checks/errors/rust-hir-pattern-analysis.cc new file mode 100644 index 000000000000..fdbc6e8d2ec0 --- /dev/null +++ b/gcc/rust/checks/errors/rust-hir-pattern-analysis.cc @@ -0,0 +1,1554 @@ +// Copyright (C) 2020-2024 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#include "rust-system.h" +#include "rust-hir-pattern-analysis.h" +#include "rust-diagnostics.h" +#include "rust-hir-full-decls.h" +#include "rust-hir-path.h" +#include "rust-hir-pattern.h" +#include "rust-hir.h" +#include "rust-mapping-common.h" +#include "rust-system.h" +#include "rust-tyty.h" + +namespace Rust { +namespace Analysis { + +PatternChecker::PatternChecker () + : tyctx (*Resolver::TypeCheckContext::get ()), + resolver (*Resolver::Resolver::get ()), + mappings (Analysis::Mappings::get ()) +{} + +void +PatternChecker::go (HIR::Crate &crate) +{ + rust_debug ("started pattern check"); + for (auto &item : crate.get_items ()) + item->accept_vis (*this); + rust_debug ("finished pattern check"); +} + +void +PatternChecker::visit (Lifetime &) +{} + +void +PatternChecker::visit (LifetimeParam &) +{} + +void +PatternChecker::visit (PathInExpression &path) +{} + +void +PatternChecker::visit (TypePathSegment &) +{} + +void +PatternChecker::visit (TypePathSegmentGeneric &) +{} + +void +PatternChecker::visit (TypePathSegmentFunction &) +{} + +void +PatternChecker::visit (TypePath &) +{} + +void +PatternChecker::visit (QualifiedPathInExpression &) +{} + +void +PatternChecker::visit (QualifiedPathInType &) +{} + +void +PatternChecker::visit (LiteralExpr &) +{} + +void +PatternChecker::visit (BorrowExpr &expr) +{ + expr.get_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (DereferenceExpr &expr) +{ + expr.get_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (ErrorPropagationExpr &expr) +{ + expr.get_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (NegationExpr &expr) +{ + expr.get_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (ArithmeticOrLogicalExpr &expr) +{ + expr.get_lhs ()->accept_vis (*this); + expr.get_rhs ()->accept_vis (*this); +} + +void +PatternChecker::visit (ComparisonExpr &expr) +{ + expr.get_lhs ()->accept_vis (*this); + expr.get_rhs ()->accept_vis (*this); +} + +void +PatternChecker::visit (LazyBooleanExpr &expr) +{ + expr.get_lhs ()->accept_vis (*this); + expr.get_rhs ()->accept_vis (*this); +} + +void +PatternChecker::visit (TypeCastExpr &expr) +{ + expr.get_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (AssignmentExpr &expr) +{ + expr.get_lhs ()->accept_vis (*this); + expr.get_rhs ()->accept_vis (*this); +} + +void +PatternChecker::visit (CompoundAssignmentExpr &expr) +{ + expr.get_lhs ()->accept_vis (*this); + expr.get_rhs ()->accept_vis (*this); +} + +void +PatternChecker::visit (GroupedExpr &expr) +{ + expr.get_expr_in_parens ()->accept_vis (*this); +} + +void +PatternChecker::visit (ArrayElemsValues &elems) +{ + for (auto &elem : elems.get_values ()) + elem->accept_vis (*this); +} + +void +PatternChecker::visit (ArrayElemsCopied &elems) +{ + elems.get_elem_to_copy ()->accept_vis (*this); +} + +void +PatternChecker::visit (ArrayExpr &expr) +{ + expr.get_internal_elements ()->accept_vis (*this); +} + +void +PatternChecker::visit (ArrayIndexExpr &expr) +{ + expr.get_array_expr ()->accept_vis (*this); + expr.get_index_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (TupleExpr &expr) +{ + for (auto &elem : expr.get_tuple_elems ()) + elem->accept_vis (*this); +} + +void +PatternChecker::visit (TupleIndexExpr &expr) +{ + expr.get_tuple_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (StructExprStruct &) +{} + +void +PatternChecker::visit (StructExprFieldIdentifier &) +{} + +void +PatternChecker::visit (StructExprFieldIdentifierValue &field) +{ + field.get_value ()->accept_vis (*this); +} + +void +PatternChecker::visit (StructExprFieldIndexValue &field) +{ + field.get_value ()->accept_vis (*this); +} + +void +PatternChecker::visit (StructExprStructFields &expr) +{ + for (auto &field : expr.get_fields ()) + field->accept_vis (*this); +} + +void +PatternChecker::visit (StructExprStructBase &) +{} + +void +PatternChecker::visit (CallExpr &expr) +{ + if (!expr.get_fnexpr ()) + return; + + NodeId ast_node_id = expr.get_fnexpr ()->get_mappings ().get_nodeid (); + NodeId ref_node_id; + if (!resolver.lookup_resolved_name (ast_node_id, &ref_node_id)) + return; + + if (auto definition_id = mappings.lookup_node_to_hir (ref_node_id)) + { + if (expr.has_params ()) + for (auto &arg : expr.get_arguments ()) + arg->accept_vis (*this); + } + else + { + rust_unreachable (); + } +} + +void +PatternChecker::visit (MethodCallExpr &expr) +{ + expr.get_receiver ()->accept_vis (*this); + + for (auto &arg : expr.get_arguments ()) + arg->accept_vis (*this); +} + +void +PatternChecker::visit (FieldAccessExpr &expr) +{ + expr.get_receiver_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (ClosureExpr &expr) +{ + expr.get_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (BlockExpr &expr) +{ + for (auto &stmt : expr.get_statements ()) + stmt->accept_vis (*this); + + if (expr.has_expr ()) + expr.get_final_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (ContinueExpr &) +{} + +void +PatternChecker::visit (BreakExpr &expr) +{ + if (expr.has_break_expr ()) + expr.get_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (RangeFromToExpr &expr) +{ + expr.get_from_expr ()->accept_vis (*this); + expr.get_to_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (RangeFromExpr &expr) +{ + expr.get_from_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (RangeToExpr &expr) +{ + expr.get_to_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (RangeFullExpr &) +{} + +void +PatternChecker::visit (RangeFromToInclExpr &expr) +{ + expr.get_from_expr ()->accept_vis (*this); + expr.get_to_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (RangeToInclExpr &expr) +{ + expr.get_to_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (ReturnExpr &expr) +{ + if (expr.has_return_expr ()) + expr.get_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (UnsafeBlockExpr &expr) +{ + expr.get_block_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (LoopExpr &expr) +{ + expr.get_loop_block ()->accept_vis (*this); +} + +void +PatternChecker::visit (WhileLoopExpr &expr) +{ + expr.get_predicate_expr ()->accept_vis (*this); + expr.get_loop_block ()->accept_vis (*this); +} + +void +PatternChecker::visit (WhileLetLoopExpr &expr) +{ + expr.get_cond ()->accept_vis (*this); + expr.get_loop_block ()->accept_vis (*this); +} + +void +PatternChecker::visit (IfExpr &expr) +{ + expr.get_if_condition ()->accept_vis (*this); + expr.get_if_block ()->accept_vis (*this); +} + +void +PatternChecker::visit (IfExprConseqElse &expr) +{ + expr.get_if_condition ()->accept_vis (*this); + expr.get_if_block ()->accept_vis (*this); + expr.get_else_block ()->accept_vis (*this); +} + +void +PatternChecker::visit (IfLetExpr &expr) +{ + expr.get_scrutinee_expr ()->accept_vis (*this); + expr.get_if_block ()->accept_vis (*this); +} + +void +PatternChecker::visit (IfLetExprConseqElse &expr) +{ + expr.get_scrutinee_expr ()->accept_vis (*this); + expr.get_if_block ()->accept_vis (*this); + + expr.get_else_block ()->accept_vis (*this); +} + +void +PatternChecker::visit (MatchExpr &expr) +{ + expr.get_scrutinee_expr ()->accept_vis (*this); + + for (auto &match_arm : expr.get_match_cases ()) + match_arm.get_expr ()->accept_vis (*this); + + // match expressions are only an entrypoint + TyTy::BaseType *scrutinee_ty; + bool ok = tyctx.lookup_type ( + expr.get_scrutinee_expr ()->get_mappings ().get_hirid (), &scrutinee_ty); + rust_assert (ok); + + check_match_usefulness (&tyctx, scrutinee_ty, expr); +} + +void +PatternChecker::visit (AwaitExpr &) +{ + // TODO: Visit expression +} + +void +PatternChecker::visit (AsyncBlockExpr &) +{ + // TODO: Visit block expression +} + +void +PatternChecker::visit (InlineAsm &expr) +{} + +void +PatternChecker::visit (TypeParam &) +{} + +void +PatternChecker::visit (ConstGenericParam &) +{} + +void +PatternChecker::visit (LifetimeWhereClauseItem &) +{} + +void +PatternChecker::visit (TypeBoundWhereClauseItem &) +{} + +void +PatternChecker::visit (Module &module) +{ + for (auto &item : module.get_items ()) + item->accept_vis (*this); +} + +void +PatternChecker::visit (ExternCrate &) +{} + +void +PatternChecker::visit (UseTreeGlob &) +{} + +void +PatternChecker::visit (UseTreeList &) +{} + +void +PatternChecker::visit (UseTreeRebind &) +{} + +void +PatternChecker::visit (UseDeclaration &) +{} + +void +PatternChecker::visit (Function &function) +{ + function.get_definition ()->accept_vis (*this); +} + +void +PatternChecker::visit (TypeAlias &) +{} + +void +PatternChecker::visit (StructStruct &) +{} + +void +PatternChecker::visit (TupleStruct &) +{} + +void +PatternChecker::visit (EnumItem &) +{} + +void +PatternChecker::visit (EnumItemTuple &) +{} + +void +PatternChecker::visit (EnumItemStruct &) +{} + +void +PatternChecker::visit (EnumItemDiscriminant &) +{} + +void +PatternChecker::visit (Enum &) +{} + +void +PatternChecker::visit (Union &) +{} + +void +PatternChecker::visit (ConstantItem &const_item) +{ + const_item.get_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (StaticItem &static_item) +{ + static_item.get_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (TraitItemFunc &item) +{ + if (item.has_block_defined ()) + item.get_block_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (TraitItemConst &item) +{ + if (item.has_expr ()) + item.get_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (TraitItemType &) +{} + +void +PatternChecker::visit (Trait &trait) +{ + for (auto &item : trait.get_trait_items ()) + item->accept_vis (*this); +} + +void +PatternChecker::visit (ImplBlock &impl) +{ + for (auto &item : impl.get_impl_items ()) + item->accept_vis (*this); +} + +void +PatternChecker::visit (ExternalStaticItem &) +{} + +void +PatternChecker::visit (ExternalFunctionItem &) +{} + +void +PatternChecker::visit (ExternalTypeItem &) +{} + +void +PatternChecker::visit (ExternBlock &block) +{ + // FIXME: Do we need to do this? + for (auto &item : block.get_extern_items ()) + item->accept_vis (*this); +} + +void +PatternChecker::visit (LiteralPattern &) +{} + +void +PatternChecker::visit (IdentifierPattern &) +{} + +void +PatternChecker::visit (WildcardPattern &) +{} + +void +PatternChecker::visit (RangePatternBoundLiteral &) +{} + +void +PatternChecker::visit (RangePatternBoundPath &) +{} + +void +PatternChecker::visit (RangePatternBoundQualPath &) +{} + +void +PatternChecker::visit (RangePattern &) +{} + +void +PatternChecker::visit (ReferencePattern &) +{} + +void +PatternChecker::visit (StructPatternFieldTuplePat &) +{} + +void +PatternChecker::visit (StructPatternFieldIdentPat &) +{} + +void +PatternChecker::visit (StructPatternFieldIdent &) +{} + +void +PatternChecker::visit (StructPattern &) +{} + +void +PatternChecker::visit (TupleStructItemsNoRange &) +{} + +void +PatternChecker::visit (TupleStructItemsRange &) +{} + +void +PatternChecker::visit (TupleStructPattern &) +{} + +void +PatternChecker::visit (TuplePatternItemsMultiple &) +{} + +void +PatternChecker::visit (TuplePatternItemsRanged &) +{} + +void +PatternChecker::visit (TuplePattern &) +{} + +void +PatternChecker::visit (SlicePattern &) +{} + +void +PatternChecker::visit (AltPattern &) +{} + +void +PatternChecker::visit (EmptyStmt &) +{} + +void +PatternChecker::visit (LetStmt &stmt) +{ + if (stmt.has_init_expr ()) + stmt.get_init_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (ExprStmt &stmt) +{ + stmt.get_expr ()->accept_vis (*this); +} + +void +PatternChecker::visit (TraitBound &) +{} + +void +PatternChecker::visit (ImplTraitType &) +{} + +void +PatternChecker::visit (TraitObjectType &) +{} + +void +PatternChecker::visit (ParenthesisedType &) +{} + +void +PatternChecker::visit (ImplTraitTypeOneBound &) +{} + +void +PatternChecker::visit (TupleType &) +{} + +void +PatternChecker::visit (NeverType &) +{} + +void +PatternChecker::visit (RawPointerType &) +{} + +void +PatternChecker::visit (ReferenceType &) +{} + +void +PatternChecker::visit (ArrayType &) +{} + +void +PatternChecker::visit (SliceType &) +{} + +void +PatternChecker::visit (InferredType &) +{} + +void +PatternChecker::visit (BareFunctionType &) +{} + +bool +Constructor::is_covered_by (const Constructor &o) const +{ + if (o.kind == ConstructorKind::WILDCARD) + return true; + + switch (kind) + { + case ConstructorKind::VARIANT: { + rust_assert (kind == ConstructorKind::VARIANT); + return variant_idx == o.variant_idx; + } + break; + case ConstructorKind::INT_RANGE: { + rust_assert (kind == ConstructorKind::INT_RANGE); + return int_range.lo >= o.int_range.lo && int_range.hi <= o.int_range.hi; + } + break; + case ConstructorKind::WILDCARD: { + // TODO: wildcard is covered by a variant of enum with a single + // variant + return false; + } + break; + case ConstructorKind::STRUCT: { + // Struct pattern is always covered by a other struct constructor. + return true; + } + break; + // TODO: support references + case ConstructorKind::REFERENCE: + default: + rust_unreachable (); + } +} + +bool +Constructor::operator< (const Constructor &o) const +{ + if (kind != o.kind) + return kind < o.kind; + + switch (kind) + { + case ConstructorKind::VARIANT: + return variant_idx < o.variant_idx; + case ConstructorKind::INT_RANGE: + return int_range.lo < o.int_range.lo + || (int_range.lo == o.int_range.lo + && int_range.hi < o.int_range.hi); + case ConstructorKind::STRUCT: + case ConstructorKind::WILDCARD: + case ConstructorKind::REFERENCE: + return false; + default: + rust_unreachable (); + } +} + +std::string +Constructor::to_string () const +{ + switch (kind) + { + case ConstructorKind::STRUCT: + return "STRUCT"; + case ConstructorKind::VARIANT: + return "VARIANT(" + std::to_string (variant_idx) + ")"; + case ConstructorKind::INT_RANGE: + return "RANGE" + std::to_string (int_range.lo) + ".." + + std::to_string (int_range.hi); + case ConstructorKind::WILDCARD: + return "_"; + case ConstructorKind::REFERENCE: + return "REF"; + default: + rust_unreachable (); + } +} + +std::vector +DeconstructedPat::specialize (const Constructor &other_ctor, + int other_ctor_arity) const +{ + rust_assert (other_ctor.is_covered_by (ctor)); + if (ctor.is_wildcard ()) + return std::vector ( + other_ctor_arity, + DeconstructedPat (Constructor::make_wildcard (), locus)); + + return fields; +} + +std::string +DeconstructedPat::to_string () const +{ + std::string s = ctor.to_string () + "["; + for (auto &f : fields) + s += f.to_string () + ", "; + + s += "](arity=" + std::to_string (arity) + ")"; + return s; +} + +bool +PatOrWild::is_covered_by (const Constructor &c) const +{ + if (pat.has_value ()) + return pat.value ().get_ctor ().is_covered_by (c); + else + return true; +} + +std::vector +PatOrWild::specialize (const Constructor &other_ctor, + int other_ctor_arity) const +{ + if (pat.has_value ()) + { + auto v = pat.value ().specialize (other_ctor, other_ctor_arity); + std::vector ret; + for (auto &pat : v) + ret.push_back (PatOrWild::make_pattern (pat)); + + return ret; + } + else + { + return std::vector (other_ctor_arity, + PatOrWild::make_wildcard ()); + } +} + +std::string +PatOrWild::to_string () const +{ + if (pat.has_value ()) + return pat.value ().to_string (); + else + return "Wild"; +} + +void +PatStack::pop_head_constructor (const Constructor &other_ctor, + int other_ctor_arity) +{ + rust_assert (!pats.empty ()); + rust_assert (other_ctor.is_covered_by (head ().ctor ())); + + PatOrWild &hd = head (); + auto v = hd.specialize (other_ctor, other_ctor_arity); + { + std::string s = "["; + for (auto &pat : v) + s += pat.to_string () + ", "; + s += "]"; + + rust_debug ("specialize %s with %s to %s", hd.to_string ().c_str (), + other_ctor.to_string ().c_str (), s.c_str ()); + } + pop_head (); + for (auto &pat : v) + pats.push_back (pat); +} + +std::string +MatrixRow::to_string () const +{ + std::string s; + for (const PatOrWild &pat : pats.get_subpatterns ()) + s += pat.to_string () + ", "; + return s; +} + +std::vector +PlaceInfo::specialize (const Constructor &c) const +{ + switch (c.get_kind ()) + { + case Constructor::ConstructorKind::WILDCARD: + case Constructor::ConstructorKind::INT_RANGE: { + return {}; + } + break; + case Constructor::ConstructorKind::STRUCT: + case Constructor::ConstructorKind::VARIANT: { + rust_assert (ty->get_kind () == TyTy::TypeKind::ADT); + TyTy::ADTType *adt = static_cast (ty); + switch (adt->get_adt_kind ()) + { + case TyTy::ADTType::ADTKind::ENUM: + case TyTy::ADTType::ADTKind::STRUCT_STRUCT: + case TyTy::ADTType::ADTKind::TUPLE_STRUCT: { + TyTy::VariantDef *variant + = adt->get_variants ().at (c.get_variant_index ()); + if (variant->get_variant_type () + == TyTy::VariantDef::VariantType::NUM) + return {}; + + std::vector new_place_infos; + for (auto &field : variant->get_fields ()) + new_place_infos.push_back (field->get_field_type ()); + + return new_place_infos; + } + break; + case TyTy::ADTType::ADTKind::UNION: { + // TODO: support unions + rust_unreachable (); + } + } + } + break; + default: { + rust_unreachable (); + } + break; + } + + rust_unreachable (); +} + +Matrix +Matrix::specialize (const Constructor &ctor) const +{ + auto subfields_place_info = place_infos.at (0).specialize (ctor); + + std::vector new_rows; + for (const MatrixRow &row : rows) + { + PatStack pats = row.get_pats_clone (); + const PatOrWild &hd = pats.head (); + if (ctor.is_covered_by (hd.ctor ())) + { + pats.pop_head_constructor (ctor, subfields_place_info.size ()); + new_rows.push_back (MatrixRow (pats, row.is_under_guard ())); + } + } + + if (place_infos.empty ()) + return Matrix (new_rows, {}); + + // push subfields of the first fields after specialization + std::vector new_place_infos = subfields_place_info; + // add place infos for the rest of the fields + for (size_t i = 1; i < place_infos.size (); i++) + new_place_infos.push_back (place_infos.at (i)); + + return Matrix (new_rows, new_place_infos); +} + +std::string +Matrix::to_string () const +{ + std::string s = "[\n"; + for (const MatrixRow &row : rows) + s += "row: " + row.to_string () + "\n"; + + s += "](place_infos=["; + for (const PlaceInfo &place_info : place_infos) + s += place_info.get_type ()->as_string () + ", "; + + s += "])"; + return s; +} + +std::string +WitnessPat::to_string () const +{ + switch (ctor.get_kind ()) + { + case Constructor::ConstructorKind::STRUCT: { + TyTy::ADTType *adt = static_cast (ty); + TyTy::VariantDef *variant + = adt->get_variants ().at (ctor.get_variant_index ()); + std::string buf; + buf += adt->get_identifier (); + + buf += " {"; + if (!fields.empty ()) + buf += " "; + + for (size_t i = 0; i < fields.size (); i++) + { + buf += variant->get_fields ().at (i)->get_name () + ": "; + buf += fields.at (i).to_string (); + if (i < fields.size () - 1) + buf += ", "; + } + if (!fields.empty ()) + buf += " "; + + buf += "}"; + return buf; + } + break; + case Constructor::ConstructorKind::VARIANT: { + std::string buf; + TyTy::ADTType *adt = static_cast (ty); + buf += adt->get_identifier (); + TyTy::VariantDef *variant + = adt->get_variants ().at (ctor.get_variant_index ()); + buf += "::" + variant->get_identifier (); + + switch (variant->get_variant_type ()) + { + case TyTy::VariantDef::VariantType::NUM: { + return buf; + } + break; + case TyTy::VariantDef::VariantType::TUPLE: { + buf += "("; + for (size_t i = 0; i < fields.size (); i++) + { + buf += fields.at (i).to_string (); + if (i < fields.size () - 1) + buf += ", "; + } + buf += ")"; + return buf; + } + break; + case TyTy::VariantDef::VariantType::STRUCT: { + buf += " {"; + if (!fields.empty ()) + buf += " "; + + for (size_t i = 0; i < fields.size (); i++) + { + buf += variant->get_fields ().at (i)->get_name () + ": "; + buf += fields.at (i).to_string (); + if (i < fields.size () - 1) + buf += ", "; + } + + if (!fields.empty ()) + buf += " "; + + buf += "}"; + } + break; + default: { + rust_unreachable (); + } + break; + } + return buf; + } + break; + case Constructor::ConstructorKind::INT_RANGE: { + // TODO: implement + rust_unreachable (); + } + break; + case Constructor::ConstructorKind::WILDCARD: { + return "_"; + } + break; + case Constructor::ConstructorKind::REFERENCE: { + // TODO: implement + rust_unreachable (); + } + break; + default: { + rust_unreachable (); + } + break; + } + rust_unreachable (); +} + +void +WitnessMatrix::apply_constructor (const Constructor &ctor, + const std::set &missings, + TyTy::BaseType *ty) +{ + int arity = 0; + // TODO: only support struct and variant ctor for now. + switch (ctor.get_kind ()) + { + case Constructor::ConstructorKind::WILDCARD: { + arity = 0; + } + break; + case Constructor::ConstructorKind::STRUCT: + case Constructor::ConstructorKind::VARIANT: { + if (ty->get_kind () == TyTy::TypeKind::ADT) + { + TyTy::ADTType *adt = static_cast (ty); + TyTy::VariantDef *variant + = adt->get_variants ().at (ctor.get_variant_index ()); + if (variant->get_variant_type () == TyTy::VariantDef::NUM) + arity = 0; + else + arity = variant->get_fields ().size (); + } + } + break; + default: { + rust_unreachable (); + } + } + + std::string buf; + for (auto &stack : patstacks) + { + buf += "["; + for (auto &pat : stack) + buf += pat.to_string () + ", "; + + buf += "]\n"; + } + rust_debug ("witness pats:\n%s", buf.c_str ()); + + for (auto &stack : patstacks) + { + std::vector subfield; + for (int i = 0; i < arity; i++) + { + if (stack.empty ()) + subfield.push_back (WitnessPat::make_wildcard (ty)); + else + { + subfield.push_back (stack.back ()); + stack.pop_back (); + } + } + + stack.push_back (WitnessPat (ctor, subfield, ty)); + } +} + +void +WitnessMatrix::extend (const WitnessMatrix &other) +{ + patstacks.insert (patstacks.end (), other.patstacks.begin (), + other.patstacks.end ()); +} + +// forward declarations +static DeconstructedPat +lower_pattern (Resolver::TypeCheckContext *ctx, HIR::Pattern *pattern, + TyTy::BaseType *scrutinee_ty); + +static DeconstructedPat +lower_tuple_pattern (Resolver::TypeCheckContext *ctx, + HIR::TupleStructPattern *pattern, + TyTy::VariantDef *variant, Constructor &ctor) +{ + int arity = variant->get_fields ().size (); + HIR::TupleStructItems *elems = pattern->get_items ().get (); + + std::vector fields; + switch (elems->get_item_type ()) + { + case HIR::TupleStructItems::ItemType::MULTIPLE: { + HIR::TupleStructItemsNoRange *multiple + = static_cast (elems); + + rust_assert (variant->get_fields ().size () + == multiple->get_patterns ().size ()); + for (size_t i = 0; i < multiple->get_patterns ().size (); i++) + { + fields.push_back ( + lower_pattern (ctx, multiple->get_patterns ().at (i).get (), + variant->get_fields ().at (i)->get_field_type ())); + } + return DeconstructedPat (ctor, arity, fields, pattern->get_locus ()); + } + break; + case HIR::TupleStructItems::ItemType::RANGED: { + // TODO: ranged tuple struct items + rust_unreachable (); + } + break; + default: { + rust_unreachable (); + } + } +} + +static DeconstructedPat +lower_struct_pattern (Resolver::TypeCheckContext *ctx, + HIR::StructPattern *pattern, TyTy::VariantDef *variant, + Constructor ctor) +{ + int arity = variant->get_fields ().size (); + + // Initialize all field patterns to wildcard. + std::vector fields + = std::vector (arity, DeconstructedPat::make_wildcard ( + pattern->get_locus ())); + + std::map field_map; + for (int i = 0; i < arity; i++) + { + auto &f = variant->get_fields ().at (i); + field_map[f->get_name ()] = i; + } + + // Fill in the fields with the present patterns. + HIR::StructPatternElements elems = pattern->get_struct_pattern_elems (); + for (auto &elem : elems.get_struct_pattern_fields ()) + { + switch (elem->get_item_type ()) + { + case HIR::StructPatternField::ItemType::IDENT: { + HIR::StructPatternFieldIdent *ident + = static_cast (elem.get ()); + int field_idx + = field_map.at (ident->get_identifier ().as_string ()); + fields.at (field_idx) + = DeconstructedPat::make_wildcard (pattern->get_locus ()); + } + break; + case HIR::StructPatternField::ItemType::IDENT_PAT: { + HIR::StructPatternFieldIdentPat *ident_pat + = static_cast (elem.get ()); + int field_idx + = field_map.at (ident_pat->get_identifier ().as_string ()); + fields.at (field_idx) = lower_pattern ( + ctx, ident_pat->get_pattern ().get (), + variant->get_fields ().at (field_idx)->get_field_type ()); + } + break; + case HIR::StructPatternField::ItemType::TUPLE_PAT: { + // TODO: tuple: pat + rust_unreachable (); + } + break; + default: { + rust_unreachable (); + } + } + } + + return DeconstructedPat{ctor, arity, fields, pattern->get_locus ()}; +}; + +static DeconstructedPat +lower_pattern (Resolver::TypeCheckContext *ctx, HIR::Pattern *pattern, + TyTy::BaseType *scrutinee_ty) +{ + HIR::Pattern::PatternType pat_type = pattern->get_pattern_type (); + switch (pat_type) + { + case HIR::Pattern::PatternType::WILDCARD: + case HIR::Pattern::PatternType::IDENTIFIER: { + return DeconstructedPat::make_wildcard (pattern->get_locus ()); + } + break; + case HIR::Pattern::PatternType::PATH: { + // TODO: support constants, associated constants, enum variants and + // structs + // https://doc.rust-lang.org/reference/patterns.html#path-patterns + // unimplemented. Treat this pattern as wildcard for now. + return DeconstructedPat::make_wildcard (pattern->get_locus ()); + } + break; + case HIR::Pattern::PatternType::REFERENCE: { + // TODO: unimplemented. Treat this pattern as wildcard for now. + return DeconstructedPat::make_wildcard (pattern->get_locus ()); + } + break; + case HIR::Pattern::PatternType::STRUCT: + case HIR::Pattern::PatternType::TUPLE_STRUCT: { + HirId path_id = UNKNOWN_HIRID; + if (pat_type == HIR::Pattern::PatternType::STRUCT) + { + HIR::StructPattern *struct_pattern + = static_cast (pattern); + path_id = struct_pattern->get_path ().get_mappings ().get_hirid (); + } + else + { + HIR::TupleStructPattern *tuple_pattern + = static_cast (pattern); + path_id = tuple_pattern->get_path ().get_mappings ().get_hirid (); + } + + rust_assert (scrutinee_ty->get_kind () == TyTy::TypeKind::ADT); + TyTy::ADTType *adt = static_cast (scrutinee_ty); + + Constructor ctor = Constructor::make_struct (); + TyTy::VariantDef *variant; + if (adt->is_struct_struct () || adt->is_tuple_struct ()) + variant = adt->get_variants ().at (0); + else if (adt->is_enum ()) + { + HirId variant_id = UNKNOWN_HIRID; + bool ok = ctx->lookup_variant_definition (path_id, &variant_id); + rust_assert (ok); + + int variant_idx; + ok = adt->lookup_variant_by_id (variant_id, &variant, &variant_idx); + rust_assert (ok); + + ctor = Constructor::make_variant (variant_idx); + } + else + { + rust_unreachable (); + } + rust_assert (variant->get_variant_type () + == TyTy::VariantDef::VariantType::TUPLE + || variant->get_variant_type () + == TyTy::VariantDef::VariantType::STRUCT); + + if (pat_type == HIR::Pattern::PatternType::STRUCT) + { + HIR::StructPattern *struct_pattern + = static_cast (pattern); + return lower_struct_pattern (ctx, struct_pattern, variant, ctor); + } + else + { + HIR::TupleStructPattern *tuple_pattern + = static_cast (pattern); + return lower_tuple_pattern (ctx, tuple_pattern, variant, ctor); + } + } + break; + case HIR::Pattern::PatternType::TUPLE: { + // TODO: unimplemented. Treat this pattern as wildcard for now. + return DeconstructedPat::make_wildcard (pattern->get_locus ()); + } + break; + case HIR::Pattern::PatternType::SLICE: { + // TODO: unimplemented. Treat this pattern as wildcard for now. + return DeconstructedPat::make_wildcard (pattern->get_locus ()); + } + break; + case HIR::Pattern::PatternType::ALT: { + // TODO: unimplemented. Treat this pattern as wildcard for now. + return DeconstructedPat::make_wildcard (pattern->get_locus ()); + } + break; + case HIR::Pattern::PatternType::LITERAL: { + // TODO: unimplemented. Treat this pattern as wildcard for now. + return DeconstructedPat::make_wildcard (pattern->get_locus ()); + } + break; + case HIR::Pattern::PatternType::RANGE: { + // TODO: unimplemented. Treat this pattern as wildcard for now. + return DeconstructedPat::make_wildcard (pattern->get_locus ()); + } + break; + case HIR::Pattern::PatternType::GROUPED: { + // TODO: unimplemented. Treat this pattern as wildcard for now. + return DeconstructedPat::make_wildcard (pattern->get_locus ()); + } + break; + default: { + rust_unreachable (); + } + } +} + +static MatchArm +lower_arm (Resolver::TypeCheckContext *ctx, HIR::MatchCase &arm, + TyTy::BaseType *scrutinee_ty) +{ + rust_assert (arm.get_arm ().get_patterns ().size () > 0); + + DeconstructedPat pat + = lower_pattern (ctx, arm.get_arm ().get_patterns ().at (0).get (), + scrutinee_ty); + return MatchArm (pat, arm.get_arm ().has_match_arm_guard ()); +} + +std::pair, std::set> +split_constructors (std::vector &ctors, PlaceInfo &place_info) +{ + bool all_wildcard = true; + for (auto &ctor : ctors) + { + if (!ctor.is_wildcard ()) + all_wildcard = false; + } + + // first pass for the case that all patterns are wildcard + if (all_wildcard) + return std::make_pair (std::set ( + {Constructor::make_wildcard ()}), + std::set ()); + + // TODO: only support enums and structs for now. + TyTy::BaseType *ty = place_info.get_type (); + rust_assert (ty->get_kind () == TyTy::TypeKind::ADT); + TyTy::ADTType *adt = static_cast (ty); + rust_assert (adt->is_enum () || adt->is_struct_struct () + || adt->is_tuple_struct ()); + + std::set universe; + if (adt->is_enum ()) + { + for (size_t i = 0; i < adt->get_variants ().size (); i++) + universe.insert (Constructor::make_variant (i)); + } + else if (adt->is_struct_struct () || adt->is_tuple_struct ()) + { + universe.insert (Constructor::make_struct ()); + } + + std::set present; + for (auto &ctor : ctors) + { + if (ctor.is_wildcard ()) + return std::make_pair (universe, std::set ()); + else + present.insert (ctor); + } + + std::set missing; + std::set_difference (universe.begin (), universe.end (), present.begin (), + present.end (), std::inserter (missing, missing.end ())); + return std::make_pair (universe, missing); +} + +// The core of the algorithm. It computes the usefulness and exhaustiveness of a +// given matrix recursively. +// TODO: calculate usefulness +static WitnessMatrix +compute_exhaustiveness_and_usefulness (Resolver::TypeCheckContext *ctx, + Matrix &matrix) +{ + rust_debug ("call compute_exhaustiveness_and_usefulness"); + rust_debug ("matrix: %s", matrix.to_string ().c_str ()); + + if (matrix.get_rows ().empty ()) + { + // no rows left. This means a non-exhaustive pattern. + rust_debug ("non-exhaustive subpattern found"); + return WitnessMatrix::make_unit (); + } + + // Base case: there are no columns in matrix. + if (matrix.get_place_infos ().empty ()) + return WitnessMatrix::make_empty (); + + std::vector heads; + for (auto head : matrix.heads ()) + heads.push_back (head.ctor ()); + + // TODO: not sure missing ctors need to be calculated + auto ctors_and_missings + = split_constructors (heads, matrix.get_place_infos ().at (0)); + std::set ctors = ctors_and_missings.first; + std::set missings = ctors_and_missings.second; + + WitnessMatrix ret = WitnessMatrix::make_empty (); + for (auto &ctor : ctors) + { + rust_debug ("specialize with %s", ctor.to_string ().c_str ()); + // TODO: Instead of creating new matrix, we can change the original matrix + // and use it for sub-pattern matching. It will significantly reduce + // memory usage. + Matrix spec_matrix = matrix.specialize (ctor); + + WitnessMatrix witness + = compute_exhaustiveness_and_usefulness (ctx, spec_matrix); + + TyTy::BaseType *ty = matrix.get_place_infos ().at (0).get_type (); + witness.apply_constructor (ctor, missings, ty); + ret.extend (witness); + } + + return ret; +} + +static void +emit_exhaustiveness_error (Resolver::TypeCheckContext *ctx, + HIR::MatchExpr &expr, WitnessMatrix &witness) +{ + TyTy::BaseType *scrutinee_ty; + bool ok = ctx->lookup_type ( + expr.get_scrutinee_expr ()->get_mappings ().get_hirid (), &scrutinee_ty); + rust_assert (ok); + + if (!witness.empty ()) + { + std::stringstream buf; + for (size_t i = 0; i < witness.get_stacks ().size (); i++) + { + auto &stack = witness.get_stacks ().at (i); + WitnessPat w = WitnessPat::make_wildcard (scrutinee_ty); + if (!stack.empty ()) + w = stack.at (0); + + rust_debug ("Witness[%d]: %s", (int) i, w.to_string ().c_str ()); + buf << "'" << w.to_string () << "'"; + if (i != witness.get_stacks ().size () - 1) + buf << " and "; + } + rust_error_at (expr.get_scrutinee_expr ()->get_locus (), + "non-exhaustive patterns: %s not covered", + buf.str ().c_str ()); + } + else + { + rust_debug ("no witness found"); + } +} + +// Entry point for computing match usefulness and check exhaustiveness +void +check_match_usefulness (Resolver::TypeCheckContext *ctx, + TyTy::BaseType *scrutinee_ty, HIR::MatchExpr &expr) +{ + // Lower the arms to a more convenient representation. + std::vector rows; + for (auto &arm : expr.get_match_cases ()) + { + PatStack pats; + MatchArm lowered = lower_arm (ctx, arm, scrutinee_ty); + PatOrWild pat = PatOrWild::make_pattern (lowered.get_pat ()); + pats.push (pat); + rows.push_back (MatrixRow (pats, lowered.has_guard ())); + } + + std::vector place_infos = {{PlaceInfo (scrutinee_ty)}}; + Matrix matrix{rows, place_infos}; + + WitnessMatrix witness = compute_exhaustiveness_and_usefulness (ctx, matrix); + + emit_exhaustiveness_error (ctx, expr, witness); +} + +} // namespace Analysis +} // namespace Rust diff --git a/gcc/rust/checks/errors/rust-hir-pattern-analysis.h b/gcc/rust/checks/errors/rust-hir-pattern-analysis.h new file mode 100644 index 000000000000..1af02baa24b1 --- /dev/null +++ b/gcc/rust/checks/errors/rust-hir-pattern-analysis.h @@ -0,0 +1,526 @@ +// Copyright (C) 2020-2024 Free Software Foundation, Inc. + +// This file is part of GCC. + +// GCC is free software; you can redistribute it and/or modify it under +// the terms of the GNU General Public License as published by the Free +// Software Foundation; either version 3, or (at your option) any later +// version. + +// GCC is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or +// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +// for more details. + +// You should have received a copy of the GNU General Public License +// along with GCC; see the file COPYING3. If not see +// . + +#ifndef RUST_HIR_PATTERN_ANALYSIS_H +#define RUST_HIR_PATTERN_ANALYSIS_H + +#include "rust-system.h" +#include "rust-hir-expr.h" +#include "rust-hir-type-check.h" +#include "rust-system.h" +#include "rust-tyty.h" +#include "optional.h" +#include "rust-hir-visitor.h" +#include "rust-name-resolver.h" + +namespace Rust { +namespace Analysis { + +using namespace HIR; + +void +check_match_usefulness (Resolver::TypeCheckContext *ctx, + TyTy::BaseType *scrutinee_ty, HIR::MatchExpr &expr); + +class PatternChecker : public HIR::HIRFullVisitor +{ +public: + PatternChecker (); + + void go (HIR::Crate &crate); + +private: + Resolver::TypeCheckContext &tyctx; + Resolver::Resolver &resolver; + Analysis::Mappings &mappings; + + virtual void visit (Lifetime &lifetime) override; + virtual void visit (LifetimeParam &lifetime_param) override; + virtual void visit (PathInExpression &path) override; + virtual void visit (TypePathSegment &segment) override; + virtual void visit (TypePathSegmentGeneric &segment) override; + virtual void visit (TypePathSegmentFunction &segment) override; + virtual void visit (TypePath &path) override; + virtual void visit (QualifiedPathInExpression &path) override; + virtual void visit (QualifiedPathInType &path) override; + virtual void visit (LiteralExpr &expr) override; + virtual void visit (BorrowExpr &expr) override; + virtual void visit (DereferenceExpr &expr) override; + virtual void visit (ErrorPropagationExpr &expr) override; + virtual void visit (NegationExpr &expr) override; + virtual void visit (ArithmeticOrLogicalExpr &expr) override; + virtual void visit (ComparisonExpr &expr) override; + virtual void visit (LazyBooleanExpr &expr) override; + virtual void visit (TypeCastExpr &expr) override; + virtual void visit (AssignmentExpr &expr) override; + virtual void visit (CompoundAssignmentExpr &expr) override; + virtual void visit (GroupedExpr &expr) override; + virtual void visit (ArrayElemsValues &elems) override; + virtual void visit (ArrayElemsCopied &elems) override; + virtual void visit (ArrayExpr &expr) override; + virtual void visit (ArrayIndexExpr &expr) override; + virtual void visit (TupleExpr &expr) override; + virtual void visit (TupleIndexExpr &expr) override; + virtual void visit (StructExprStruct &expr) override; + virtual void visit (StructExprFieldIdentifier &field) override; + virtual void visit (StructExprFieldIdentifierValue &field) override; + virtual void visit (StructExprFieldIndexValue &field) override; + virtual void visit (StructExprStructFields &expr) override; + virtual void visit (StructExprStructBase &expr) override; + virtual void visit (CallExpr &expr) override; + virtual void visit (MethodCallExpr &expr) override; + virtual void visit (FieldAccessExpr &expr) override; + virtual void visit (BlockExpr &expr) override; + virtual void visit (ClosureExpr &expr) override; + virtual void visit (ContinueExpr &expr) override; + virtual void visit (BreakExpr &expr) override; + virtual void visit (RangeFromToExpr &expr) override; + virtual void visit (RangeFromExpr &expr) override; + virtual void visit (RangeToExpr &expr) override; + virtual void visit (RangeFullExpr &expr) override; + virtual void visit (RangeFromToInclExpr &expr) override; + virtual void visit (RangeToInclExpr &expr) override; + virtual void visit (ReturnExpr &expr) override; + virtual void visit (UnsafeBlockExpr &expr) override; + virtual void visit (LoopExpr &expr) override; + virtual void visit (WhileLoopExpr &expr) override; + virtual void visit (WhileLetLoopExpr &expr) override; + virtual void visit (IfExpr &expr) override; + virtual void visit (IfExprConseqElse &expr) override; + virtual void visit (IfLetExpr &expr) override; + virtual void visit (IfLetExprConseqElse &expr) override; + virtual void visit (HIR::MatchExpr &expr) override; + virtual void visit (AwaitExpr &expr) override; + virtual void visit (AsyncBlockExpr &expr) override; + virtual void visit (InlineAsm &expr) override; + virtual void visit (TypeParam ¶m) override; + virtual void visit (ConstGenericParam ¶m) override; + virtual void visit (LifetimeWhereClauseItem &item) override; + virtual void visit (TypeBoundWhereClauseItem &item) override; + virtual void visit (Module &module) override; + virtual void visit (ExternCrate &crate) override; + virtual void visit (UseTreeGlob &use_tree) override; + virtual void visit (UseTreeList &use_tree) override; + virtual void visit (UseTreeRebind &use_tree) override; + virtual void visit (UseDeclaration &use_decl) override; + virtual void visit (Function &function) override; + virtual void visit (TypeAlias &type_alias) override; + virtual void visit (StructStruct &struct_item) override; + virtual void visit (TupleStruct &tuple_struct) override; + virtual void visit (EnumItem &item) override; + virtual void visit (EnumItemTuple &item) override; + virtual void visit (EnumItemStruct &item) override; + virtual void visit (EnumItemDiscriminant &item) override; + virtual void visit (Enum &enum_item) override; + virtual void visit (Union &union_item) override; + virtual void visit (ConstantItem &const_item) override; + virtual void visit (StaticItem &static_item) override; + virtual void visit (TraitItemFunc &item) override; + virtual void visit (TraitItemConst &item) override; + virtual void visit (TraitItemType &item) override; + virtual void visit (Trait &trait) override; + virtual void visit (ImplBlock &impl) override; + virtual void visit (ExternalStaticItem &item) override; + virtual void visit (ExternalFunctionItem &item) override; + virtual void visit (ExternalTypeItem &item) override; + virtual void visit (ExternBlock &block) override; + virtual void visit (LiteralPattern &pattern) override; + virtual void visit (IdentifierPattern &pattern) override; + virtual void visit (WildcardPattern &pattern) override; + virtual void visit (RangePatternBoundLiteral &bound) override; + virtual void visit (RangePatternBoundPath &bound) override; + virtual void visit (RangePatternBoundQualPath &bound) override; + virtual void visit (RangePattern &pattern) override; + virtual void visit (ReferencePattern &pattern) override; + virtual void visit (StructPatternFieldTuplePat &field) override; + virtual void visit (StructPatternFieldIdentPat &field) override; + virtual void visit (StructPatternFieldIdent &field) override; + virtual void visit (StructPattern &pattern) override; + virtual void visit (TupleStructItemsNoRange &tuple_items) override; + virtual void visit (TupleStructItemsRange &tuple_items) override; + virtual void visit (TupleStructPattern &pattern) override; + virtual void visit (TuplePatternItemsMultiple &tuple_items) override; + virtual void visit (TuplePatternItemsRanged &tuple_items) override; + virtual void visit (TuplePattern &pattern) override; + virtual void visit (SlicePattern &pattern) override; + virtual void visit (AltPattern &pattern) override; + virtual void visit (EmptyStmt &stmt) override; + virtual void visit (LetStmt &stmt) override; + virtual void visit (ExprStmt &stmt) override; + virtual void visit (TraitBound &bound) override; + virtual void visit (ImplTraitType &type) override; + virtual void visit (TraitObjectType &type) override; + virtual void visit (ParenthesisedType &type) override; + virtual void visit (ImplTraitTypeOneBound &type) override; + virtual void visit (TupleType &type) override; + virtual void visit (NeverType &type) override; + virtual void visit (RawPointerType &type) override; + virtual void visit (ReferenceType &type) override; + virtual void visit (ArrayType &type) override; + virtual void visit (SliceType &type) override; + virtual void visit (InferredType &type) override; + virtual void visit (BareFunctionType &type) override; +}; + +struct IntRange +{ + int64_t lo; + int64_t hi; +}; + +class Constructor +{ +public: + enum class ConstructorKind + { + // tuple or struct + STRUCT, + // enum variant + VARIANT, + // integers + INT_RANGE, + // user-provided wildcard + WILDCARD, + // references + REFERENCE, + }; + + static Constructor make_wildcard () + { + return Constructor (ConstructorKind::WILDCARD); + } + + static Constructor make_reference () + { + return Constructor (ConstructorKind::REFERENCE); + } + + static Constructor make_struct () + { + Constructor c (ConstructorKind::STRUCT); + c.variant_idx = 0; + return c; + } + + static Constructor make_variant (int variant_idx) + { + Constructor c (ConstructorKind::VARIANT); + c.variant_idx = variant_idx; + return c; + } + + ConstructorKind get_kind () const { return kind; } + + int get_variant_index () const + { + rust_assert (kind == ConstructorKind::VARIANT + || kind == ConstructorKind::STRUCT); + return variant_idx; + } + + bool is_covered_by (const Constructor &o) const; + + bool is_wildcard () const { return kind == ConstructorKind::WILDCARD; } + + // Requrired by std::set + bool operator< (const Constructor &o) const; + + std::string to_string () const; + +private: + Constructor (ConstructorKind kind) : kind (kind), variant_idx (0) {} + ConstructorKind kind; + + union + { + // for enum variants, the variant index (always 0 for structs) + int variant_idx; + + // for integer ranges, the range + IntRange int_range; + }; +}; + +class DeconstructedPat +{ +public: + DeconstructedPat (Constructor ctor, int arity, + std::vector fields, location_t locus) + : ctor (ctor), arity (arity), fields (fields) + {} + + static DeconstructedPat make_wildcard (location_t locus) + { + return DeconstructedPat (Constructor::make_wildcard (), locus); + } + + static DeconstructedPat make_reference (location_t locus) + { + return DeconstructedPat (Constructor::make_reference (), locus); + } + + const Constructor &get_ctor () const { return ctor; } + + int get_arity () const { return arity; } + + std::vector specialize (const Constructor &other_ctor, + int other_ctor_arity) const; + + std::string to_string () const; + +private: + DeconstructedPat (Constructor ctor, location_t locus) + : ctor (ctor), arity (0), locus (locus) + {} + + Constructor ctor; + int arity; + std::vector fields; + location_t locus; +}; + +class PatOrWild +{ +public: + static PatOrWild make_pattern (DeconstructedPat pat) + { + return PatOrWild (pat); + } + + static PatOrWild make_wildcard () { return PatOrWild ({}); } + + bool is_wildcard () const + { + return !(pat.has_value () && !pat.value ().get_ctor ().is_wildcard ()); + } + + bool is_covered_by (const Constructor &c) const; + + // Returns the pattern if it is not a wildcard. + const tl::optional &get_pat () const + { + rust_assert (pat.has_value ()); + return pat; + } + + Constructor ctor () const + { + if (pat.has_value ()) + return pat.value ().get_ctor (); + else + return Constructor::make_wildcard (); + } + + std::vector specialize (const Constructor &other_ctor, + int other_ctor_arity) const; + + std::string to_string () const; + +private: + PatOrWild (tl::optional pat) : pat (pat) {} + + tl::optional pat; +}; + +class PatStack +{ +public: + PatStack () : relevant (false) {} + + void push (PatOrWild pat) { pats.push_back (pat); } + + bool empty () const { return pats.empty (); } + + PatOrWild &head () + { + rust_assert (!pats.empty ()); + return pats.front (); + } + + const PatOrWild &head () const + { + rust_assert (!pats.empty ()); + return pats.front (); + } + + // Only called if the head is a constructor which is convered by o. + void pop_head_constructor (const Constructor &other_ctor, + int other_ctor_arity); + + const std::deque &get_subpatterns () const { return pats; } + +private: + void pop_head () { pats.pop_front (); } + + std::deque pats; + bool relevant; +}; + +class MatrixRow +{ +public: + MatrixRow (PatStack pats, bool is_under_guard_) + : pats (pats), is_under_guard_ (is_under_guard_) + // useful (false), + // head_is_branch (false), + {} + + PatStack &get_pats () { return pats; } + + PatStack get_pats_clone () const { return pats; } + + const PatOrWild &head () const { return pats.head (); } + PatOrWild &head () { return pats.head (); } + + bool is_under_guard () const { return is_under_guard_; } + + std::string to_string () const; + +private: + PatStack pats; + bool is_under_guard_; + // TODO: manage usefulness +}; + +class PlaceInfo +{ +public: + PlaceInfo (TyTy::BaseType *ty) : ty (ty) {} + + TyTy::BaseType *get_type () const { return ty; } + + std::vector specialize (const Constructor &c) const; + +private: + TyTy::BaseType *ty; +}; + +class Matrix +{ +public: + Matrix (std::vector rows, std::vector place_infos) + : rows (rows), place_infos (place_infos) + {} + + Matrix () {} + + std::vector &get_rows () { return rows; } + + void push_row (const MatrixRow &row) { rows.push_back (row); } + + std::vector &get_place_infos () { return place_infos; } + + std::vector heads () const + { + std::vector ret; + for (const MatrixRow &row : rows) + ret.push_back (row.head ()); + + return ret; + } + + Matrix specialize (const Constructor &ctor) const; + + std::string to_string () const; + +private: + std::vector rows; + std::vector place_infos; +}; + +class MatchArm +{ +public: + MatchArm (DeconstructedPat pat, bool has_guard_) + : pat (pat), has_guard_ (has_guard_) + {} + + DeconstructedPat get_pat () const { return pat; } + + bool has_guard () const { return has_guard_; } + +private: + DeconstructedPat pat; + bool has_guard_; +}; + +class WitnessPat +{ +public: + WitnessPat (Constructor ctor, std::vector fields, + TyTy::BaseType *ty) + : ctor (ctor), fields (fields), ty (ty) + {} + + static WitnessPat make_wildcard (TyTy::BaseType *ty) + { + return WitnessPat (Constructor::make_wildcard (), {}, ty); + } + + const Constructor &get_ctor () const { return ctor; } + + const std::vector &get_fields () const { return fields; } + + TyTy::BaseType *get_type () const { return ty; } + + std::string to_string () const; + +private: + Constructor ctor; + std::vector fields; + TyTy::BaseType *ty; +}; + +class WitnessMatrix +{ +public: + // Create an empty witness matrix. + static WitnessMatrix make_empty () { return WitnessMatrix ({}); } + + // Create a unit witness matrix, a new single witness. + static WitnessMatrix make_unit () + { + return WitnessMatrix ({std::vector ()}); + } + + bool empty () const { return patstacks.empty (); } + + const std::vector> &get_stacks () const + { + return patstacks; + } + + // Reverses specialization. + void apply_constructor (const Constructor &ctor, + const std::set &missings, + TyTy::BaseType *ty); + + void extend (const WitnessMatrix &other); + +private: + WitnessMatrix (std::vector> patstacks) + : patstacks (patstacks) + {} + + std::vector> patstacks; +}; + +} // namespace Analysis +} // namespace Rust + +#endif diff --git a/gcc/rust/rust-session-manager.cc b/gcc/rust/rust-session-manager.cc index f132cf99f140..eee2058a3844 100644 --- a/gcc/rust/rust-session-manager.cc +++ b/gcc/rust/rust-session-manager.cc @@ -18,6 +18,7 @@ #include "rust-session-manager.h" #include "rust-diagnostics.h" +#include "rust-hir-pattern-analysis.h" #include "rust-immutable-name-resolution-context.h" #include "rust-unsafe-checker.h" #include "rust-lex.h" @@ -668,6 +669,11 @@ Session::compile_crate (const char *filename) Resolver::TypeCheckContext::get ()->get_variance_analysis_ctx ().solve (); + if (saw_errors ()) + return; + + Analysis::PatternChecker ().go (hir); + if (saw_errors ()) return; diff --git a/gcc/rust/typecheck/rust-hir-type-check-pattern.cc b/gcc/rust/typecheck/rust-hir-type-check-pattern.cc index b0e4ca52a93d..1265564b2ac9 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-pattern.cc +++ b/gcc/rust/typecheck/rust-hir-type-check-pattern.cc @@ -17,6 +17,7 @@ // . #include "rust-hir-type-check-pattern.h" +#include "rust-hir-pattern.h" #include "rust-hir-type-check-expr.h" #include "rust-type-util.h" @@ -228,6 +229,7 @@ TypeCheckPattern::visit (HIR::TupleStructPattern &pattern) // setup the type on this pattern type context->insert_type (pattern->get_mappings (), fty); + TypeCheckPattern::Resolve (pattern.get (), fty); } } break; diff --git a/gcc/testsuite/rust/compile/exhaustiveness1.rs b/gcc/testsuite/rust/compile/exhaustiveness1.rs new file mode 100644 index 000000000000..fe95ea3c9d9c --- /dev/null +++ b/gcc/testsuite/rust/compile/exhaustiveness1.rs @@ -0,0 +1,53 @@ +struct S { + a: i32, +} + +fn s1(s: S) { + match s { + S { a: _ } => {} + } +} + +fn s2(s: S) { + match s { + _ => {} + } +} + +fn s3(s: S) { + match s { + // { dg-error "non-exhaustive patterns: '_' not covered" "" { target *-*-* } .-1 } + } +} + +enum E { + A(), + B(), + C(), +} + +fn e1(e: E) { + match e { + // { dg-error "non-exhaustive patterns: 'E::B..' not covered" "" { target *-*-* } .-1 } + E::A() => {} + E::C() => {} + } +} + +fn e2(e: E) { + match e { + // { dg-error "non-exhaustive patterns: 'E::A..' not covered" "" { target *-*-* } .-1 } + E::B() => {} + E::C() => {} + } +} + +fn e3(e: E) { + match e { + E::A() => {} + E::B() => {} + E::C() => {} + } +} + +fn main() {} diff --git a/gcc/testsuite/rust/compile/exhaustiveness2.rs b/gcc/testsuite/rust/compile/exhaustiveness2.rs new file mode 100644 index 000000000000..f2e00085cc17 --- /dev/null +++ b/gcc/testsuite/rust/compile/exhaustiveness2.rs @@ -0,0 +1,28 @@ +enum E1 { + E2(E2), + None, +} + +enum E2 { + E3(E3), + None, +} + +enum E3 { + S(S), + None, +} + +struct S { + a: i32, + b: u64, +} + +fn f1(e: E1) { + match e { + // { dg-error "non-exhaustive patterns: 'E1::E2.E2::None.' and 'E1::None' not covered" "" { target *-*-* } .-1 } + E1::E2(E2::E3(_)) => {} + } +} + +fn main() {} diff --git a/gcc/testsuite/rust/compile/exhaustiveness3.rs b/gcc/testsuite/rust/compile/exhaustiveness3.rs new file mode 100644 index 000000000000..4a5dc1c9a328 --- /dev/null +++ b/gcc/testsuite/rust/compile/exhaustiveness3.rs @@ -0,0 +1,55 @@ +struct S { + e1: E1, + e2: E2, +} + +enum E1 { + A(), + B(), + C(), +} + +enum E2 { + D(), + E(), +} + +// This is a valid match +fn f(s: S) { + match s { + S { + e1: E1::A(), + e2: E2::D(), + } => {} + S { + e1: E1::B(), + e2: E2::D(), + } => {} + S { + e1: E1::C(), + e2: E2::D(), + } => {} + S { + e1: E1::A(), + e2: E2::E(), + } => {} + S { + e1: E1::B(), + e2: E2::E(), + } => {} + S { + e1: E1::C(), + e2: E2::E(), + } => {} + } +} + +fn f2(s: S) { + match s { + // { dg-error "non-exhaustive patterns: 'S { e1: E1::B.., e2: E2::D.. }' and 'S { e1: E1::C.., e2: E2::D.. }' not covered" "" { target *-*-* } .-1 } + S { e1: E1::A(), e2: _ } => {} + S { e1: _, e2: E2::E() } => {} + } +} + +fn main() {}