From f339313415a64bf9bcdd4ec087ed94bec30907d4 Mon Sep 17 00:00:00 2001 From: aherlihy Date: Fri, 10 Jan 2025 15:16:07 -0500 Subject: [PATCH] Move derived type check into NamedTuple.unapply for consistency --- .../dotty/tools/dotc/core/Definitions.scala | 23 ++++++++++--- .../src/dotty/tools/dotc/core/TypeUtils.scala | 33 +++++-------------- .../tools/dotc/interactive/Completion.scala | 2 +- .../tools/dotc/printing/RefinedPrinter.scala | 6 ++-- .../dotty/tools/dotc/typer/Implicits.scala | 2 +- .../src/dotty/tools/dotc/typer/Typer.scala | 2 +- tests/run/i22150.scala | 2 +- 7 files changed, 36 insertions(+), 34 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 2890bdf306be..dd20c2db9192 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1337,10 +1337,25 @@ class Definitions { object NamedTuple: def apply(nmes: Type, vals: Type)(using Context): Type = AppliedType(NamedTupleTypeRef, nmes :: vals :: Nil) - def unapply(t: Type)(using Context): Option[(Type, Type)] = t match - case AppliedType(tycon, nmes :: vals :: Nil) if tycon.typeSymbol == NamedTupleTypeRef.symbol => - Some((nmes, vals)) - case _ => None + def unapply(t: Type)(using Context): Option[(Type, Type)] = + t match + case AppliedType(tycon, nmes :: vals :: Nil) if tycon.typeSymbol == NamedTupleTypeRef.symbol => + Some((nmes, vals)) + case tp: TypeProxy => + val t = unapply(tp.superType); t + case tp: OrType => + (unapply(tp.tp1), unapply(tp.tp2)) match + case (Some(lhsName, lhsVal), Some(rhsName, rhsVal)) if lhsName == rhsName => + Some(lhsName, lhsVal | rhsVal) + case _ => None + case tp: AndType => + (unapply(tp.tp1), unapply(tp.tp2)) match + case (Some(lhsName, lhsVal), Some(rhsName, rhsVal)) if lhsName == rhsName => + Some(lhsName, lhsVal & rhsVal) + case (lhs, None) => lhs + case (None, rhs) => rhs + case _ => None + case _ => None final def isCompiletime_S(sym: Symbol)(using Context): Boolean = sym.name == tpnme.S && sym.owner == CompiletimeOpsIntModuleClass diff --git a/compiler/src/dotty/tools/dotc/core/TypeUtils.scala b/compiler/src/dotty/tools/dotc/core/TypeUtils.scala index e272c96c9d39..024b9b00d88a 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeUtils.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeUtils.scala @@ -129,26 +129,20 @@ class TypeUtils: def namedTupleElementTypesUpTo(bound: Int, derived: Boolean, normalize: Boolean = true)(using Context): List[(TermName, Type)] = (if normalize then self.normalized else self).dealias match - case defn.NamedTuple(nmes, vals) => + // for desugaring, ignore derived types to avoid infinite recursion in NamedTuple.unapply + case AppliedType(tycon, nmes :: vals :: Nil) if !derived && tycon.typeSymbol == defn.NamedTupleTypeRef.symbol => + val names = nmes.tupleElementTypesUpTo(bound, normalize).getOrElse(Nil).map(_.dealias).map: + case ConstantType(Constant(str: String)) => str.toTermName + case t => throw TypeError(em"Malformed NamedTuple: names must be string types, but $t was found.") + val values = vals.tupleElementTypesUpTo(bound, normalize).getOrElse(Nil) + names.zip(values) + // default cause, used for post-typing + case defn.NamedTuple(nmes, vals) if derived => val names = nmes.tupleElementTypesUpTo(bound, normalize).getOrElse(Nil).map(_.dealias).map: case ConstantType(Constant(str: String)) => str.toTermName case t => throw TypeError(em"Malformed NamedTuple: names must be string types, but $t was found.") val values = vals.tupleElementTypesUpTo(bound, normalize).getOrElse(Nil) names.zip(values) - case tp: TypeProxy if derived => - tp.superType.namedTupleElementTypesUpTo(bound - 1, normalize) - case tp: OrType if derived => - val lhs = tp.tp1.namedTupleElementTypesUpTo(bound - 1, normalize) - val rhs = tp.tp2.namedTupleElementTypesUpTo(bound - 1, normalize) - if (lhs.map(_._1) != rhs.map(_._1)) throw TypeError(em"Malformed Union Type: Named Tuple elements must be the same, but $lhs and $rhs were found.") - lhs.zip(rhs).map((lhs, rhs) => (lhs._1, lhs._2 | rhs._2)) - case tp: AndType if derived => - (tp.tp1.namedTupleElementTypesUpTo(bound - 1, normalize), tp.tp2.namedTupleElementTypesUpTo(bound - 1, normalize)) match - case (Nil, rhs) => rhs - case (lhs, Nil) => lhs - case (lhs, rhs) => - if (lhs.map(_._1) != rhs.map(_._1)) throw TypeError(em"Malformed Intersection Type: Named Tuple elements must be the same, but $lhs and $rhs were found.") - lhs.zip(rhs).map((lhs, rhs) => (lhs._1, lhs._2 & rhs._2)) case t => Nil @@ -159,15 +153,6 @@ class TypeUtils: case defn.NamedTuple(_, _) => true case _ => false - def derivesFromNamedTuple(using Context): Boolean = self match - case defn.NamedTuple(_, _) => true - case tp: MatchType => - tp.bound.derivesFromNamedTuple || tp.reduced.derivesFromNamedTuple - case tp: TypeProxy => tp.superType.derivesFromNamedTuple - case tp: AndType => tp.tp1.derivesFromNamedTuple || tp.tp2.derivesFromNamedTuple - case tp: OrType => tp.tp1.derivesFromNamedTuple && tp.tp2.derivesFromNamedTuple - case _ => false - /** Drop all named elements in tuple type */ def stripNamedTuple(using Context): Type = self.normalized.dealias match case defn.NamedTuple(_, vals) => diff --git a/compiler/src/dotty/tools/dotc/interactive/Completion.scala b/compiler/src/dotty/tools/dotc/interactive/Completion.scala index 6655998d026f..333af6a26b3b 100644 --- a/compiler/src/dotty/tools/dotc/interactive/Completion.scala +++ b/compiler/src/dotty/tools/dotc/interactive/Completion.scala @@ -543,7 +543,7 @@ object Completion: .groupByName val qualTpe = qual.typeOpt - if qualTpe.derivesFromNamedTuple then + if qualTpe.isNamedTupleType then namedTupleCompletionsFromType(qualTpe) else if qualTpe.derivesFrom(defn.SelectableClass) then val pre = if !TypeOps.isLegalPrefix(qualTpe) then Types.SkolemType(qualTpe) else qualTpe diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index 89d6f16427c1..04a43e9b9059 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -248,8 +248,10 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { def appliedText(tp: Type): Text = tp match case tp @ AppliedType(tycon, args) => val namedElems = - try tp.namedTupleElementTypesUpTo(200, false, normalize = false) // TODO: should the printer use derived or not? - catch case ex: TypeError => Nil + try tp.namedTupleElementTypesUpTo(200, false, normalize = false) + catch + case ex: TypeError => Nil + case ex: StackOverflowError => Nil if namedElems.nonEmpty then toTextNamedTuple(namedElems) else tp.tupleElementTypesUpTo(200, normalize = false) match diff --git a/compiler/src/dotty/tools/dotc/typer/Implicits.scala b/compiler/src/dotty/tools/dotc/typer/Implicits.scala index 193cc443b4ae..9d273ebca866 100644 --- a/compiler/src/dotty/tools/dotc/typer/Implicits.scala +++ b/compiler/src/dotty/tools/dotc/typer/Implicits.scala @@ -876,7 +876,7 @@ trait Implicits: || inferView(dummyTreeOfType(from), to) (using ctx.fresh.addMode(Mode.ImplicitExploration).setExploreTyperState()).isSuccess // TODO: investigate why we can't TyperState#test here - || from.widen.derivesFromNamedTuple && to.derivesFrom(defn.TupleClass) + || from.widen.isNamedTupleType && to.derivesFrom(defn.TupleClass) && from.widen.stripNamedTuple <:< to ) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 6072c496e1bd..9b7e4fe36668 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -4663,7 +4663,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer case _: SelectionProto => tree // adaptations for selections are handled in typedSelect case _ if ctx.mode.is(Mode.ImplicitsEnabled) && tree.tpe.isValueType => - if tree.tpe.derivesFromNamedTuple && pt.derivesFrom(defn.TupleClass) then + if tree.tpe.isNamedTupleType && pt.derivesFrom(defn.TupleClass) then readapt(typed(untpd.Select(untpd.TypedSplice(tree), nme.toTuple))) else if pt.isRef(defn.AnyValClass, skipRefined = false) || pt.isRef(defn.ObjectClass, skipRefined = false) diff --git a/tests/run/i22150.scala b/tests/run/i22150.scala index 6a01e5da85ba..7c89b1de57c5 100644 --- a/tests/run/i22150.scala +++ b/tests/run/i22150.scala @@ -10,7 +10,7 @@ val directionsNT = IArray( val IArray(UpNT @ _, _, _, _) = directionsNT object NT: -// def foo[T <: (x: Int, y: String)](tup: T): Int = +// def foo[T <: (x: Int, y: String)](tup: T): Int = // TODO 3: this fails with similar error to https://github.com/scala/scala3/issues/22324 not sure if related? // tup.x def union[T](tup: (x: Int, y: String) | (x: Int, y: String)): Int =