Skip to content

Commit

Permalink
Centralise & split up Extractor model
Browse files Browse the repository at this point in the history
  • Loading branch information
dwijnand committed Dec 21, 2023
1 parent 1e44d07 commit d0a1cb3
Show file tree
Hide file tree
Showing 4 changed files with 316 additions and 315 deletions.
99 changes: 45 additions & 54 deletions compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -282,117 +282,108 @@ object PatternMatcher {
/** Plan for matching the sequence in `seqSym` against sequence elements `args`.
* If `exact` is true, the sequence is not permitted to have any elements following `args`.
*/
def matchElemsPlan(seqSym: Symbol, args: List[Tree], exact: Boolean, onSuccess: Plan) = {
def matchElemsPlan(seqSym: Symbol, args: List[Tree], applySym: Symbol, exact: Boolean, onSuccess: Plan) = {
val selectors = args.indices.toList.map(idx =>
ref(seqSym).select(defn.Seq_apply.matchingMember(seqSym.info)).appliedTo(Literal(Constant(idx))))
ref(seqSym).select(applySym).appliedTo(Literal(Constant(idx))))
TestPlan(LengthTest(args.length, exact), seqSym, seqSym.span,
matchArgsPlan(selectors, args, onSuccess))
}

/** Plan for matching the sequence in `getResult` against sequence elements
* and a possible last varargs argument `args`.
*/
def unapplySeqPlan(getResult: Symbol, args: List[Tree]): Plan = args.lastOption match {
def unapplySeqPlan(unapp: UnapplySeqInfo, getResult: Symbol, args: List[Tree]): Plan = args.lastOption match {
case Some(VarArgPattern(arg)) =>
val matchRemaining =
if (args.length == 1) {
val toSeq = ref(getResult)
.select(defn.Seq_toSeq.matchingMember(getResult.info))
val toSeq = ref(getResult).select(unapp.toSeqDenot.symbol)
letAbstract(toSeq) { toSeqResult =>
patternPlan(toSeqResult, arg, onSuccess)
}
}
else {
val dropped = ref(getResult)
.select(defn.Seq_drop.matchingMember(getResult.info))
.select(unapp.dropDenot.symbol)
.appliedTo(Literal(Constant(args.length - 1)))
letAbstract(dropped) { droppedResult =>
patternPlan(droppedResult, arg, onSuccess)
}
}
matchElemsPlan(getResult, args.init, exact = false, matchRemaining)
matchElemsPlan(getResult, args.init, unapp.applyDenot.symbol, exact = false, matchRemaining)
case _ =>
matchElemsPlan(getResult, args, exact = true, onSuccess)
matchElemsPlan(getResult, args, unapp.applyDenot.symbol, exact = true, onSuccess)
}

/** Plan for matching the sequence in `getResult`
*
* `getResult` is a product, where the last element is a sequence of elements.
*/
def unapplyProductSeqPlan(getResult: Symbol, args: List[Tree], arity: Int): Plan = {
assert(arity <= args.size + 1)
val selectors = productSelectors(getResult.info).map(ref(getResult).select(_))
def unapplyProductSeqPlan(ext: ProdSeqMatch, getResult: Symbol, args: List[Tree]): Plan = {
val selectors = ext.productSelectors.map(ref(getResult).select(_))
val (prodArgs, seqArgs) = args.splitAt(selectors.size - 1)

val matchSeq =
letAbstract(selectors.last) { seqResult =>
unapplySeqPlan(seqResult, args.drop(arity - 1))
unapplySeqPlan(ext.unapplySeqInfo, seqResult, seqArgs)
}
matchArgsPlan(selectors.take(arity - 1), args.take(arity - 1), matchSeq)
matchArgsPlan(selectors.init, prodArgs, matchSeq)
}

/** Plan for matching the result of an unapply against argument patterns `args` */
def unapplyPlan(unapp: Tree, args: List[Tree]): Plan = {
val resTp = unapp.tpe.widen.finalResultType
def caseClass = unapp.symbol.owner.linkedClass
lazy val caseAccessors = caseClass.caseAccessors

def isSyntheticScala2Unapply(sym: Symbol) =
sym.is(Synthetic) && sym.owner.is(Scala2x)

def tupleApp(i: Int, receiver: Tree) = // manually inlining the call to NonEmptyTuple#apply, because it's an inline method
extension (recv: Tree) def tupleApply(i: Int): Tree =
// manually inlining the call to NonEmptyTuple#apply, because it's an inline method
ref(defn.RuntimeTuplesModule)
.select(defn.RuntimeTuples_apply)
.appliedTo(receiver, Literal(Constant(i)))
.appliedTo(recv, Literal(Constant(i)))

def maybeGet(getTp: Type)(f: (Type, Symbol) => Plan) =
def maybeGet(getMatch: GetMatchInfo)(f: Symbol => Plan) =
letAbstract(unapp): unappResult =>
if getTp == NoType then
f(resTp, unappResult)
if !getMatch.isValid then
f(unappResult)
else
val argsPlan =
val get = ref(unappResult).select(nme.get, _.info.isParameterless)
val get = ref(unappResult).select(getMatch.getDenot.symbol)
letAbstract(get): getResult =>
f(get.tpe, getResult)
f(getResult)
TestPlan(NonEmptyTest, unappResult, unapp.span, argsPlan)

if (isSyntheticScala2Unapply(unapp.symbol) && caseAccessors.length == args.length)
def tupleSel(sym: Symbol) = ref(scrutinee).select(sym)
val unapplyResult = unapp.tpe.widen.finalResultType
val ext = Extractor(unapplyResult, unapp.symbol.name, args.length)

if isSyntheticScala2Case(unapp.symbol)
&& caseAccessors.length == args.length // eg. case Some(a, b)
then
val isGenericTuple = defn.isTupleClass(caseClass) &&
!defn.isTupleNType(tree.tpe match { case tp: OrType => tp.join case tp => tp }) // widen even hard unions, to see if it's a union of tuples
val components = if isGenericTuple then caseAccessors.indices.toList.map(tupleApp(_, ref(scrutinee))) else caseAccessors.map(tupleSel)
val components = if isGenericTuple
then caseAccessors.indices.toList.map(ref(scrutinee).tupleApply)
else caseAccessors.map(ref(scrutinee).select)
matchArgsPlan(components, args, onSuccess)
else extractorKind(resTp, unapp.symbol.name, args.length) match
case BooleanMatch() =>
TestPlan(GuardTest, unapp, unapp.span, onSuccess)
case ProductMatch(getTp) =>
maybeGet(getTp): (tp, res) =>
val selectors = productSelectors(tp).take(args.length).map(ref(res).select(_))
matchArgsPlan(selectors, args, onSuccess)
case TupleMatch(getTp) =>
maybeGet(getTp): (tp, res) =>
val components = tupleComponentTypes2(tp).indices.toList.map(tupleApp(_, ref(res)))
matchArgsPlan(components, args, onSuccess)
case SingleMatch(getTp) =>
maybeGet(getTp): (tp, res) =>
matchArgsPlan(ref(res) :: Nil, args, onSuccess)

case SeqMatch(getTp, elemTp) =>
maybeGet(getTp): (tp, res) =>
unapplySeqPlan(res, args)
case ProdSeqMatch(getTp) =>
maybeGet(getTp): (tp, res) =>
unapplyProductSeqPlan(res, args, productArity(tp))

case x @ NoExtractor => unreachable(x)
else if ext.isInstanceOf[BooleanMatch] then
TestPlan(GuardTest, unapp, unapp.span, onSuccess)
else
maybeGet(ext.getMatchInfo): res =>
ext match
case ext: BooleanMatch => unreachable(ext) // handled above
case ext: ProductMatch => matchArgsPlan(ext.productSelectors.map(ref(res).select(_)), args, onSuccess)
case ext: TupleMatch => matchArgsPlan(ext.tupleComponentTypes.indices.map(ref(res).tupleApply).toList, args, onSuccess)
case ext: SingleMatch => matchArgsPlan(ref(res) :: Nil, args, onSuccess)
case ext: NameBasedMatch => matchArgsPlan(ext.productSelectors.map(ref(res).select(_)), args, onSuccess)
case ext: SeqMatch => unapplySeqPlan(ext.unapplySeqInfo, res, args)
case ext: ProdSeqMatch => unapplyProductSeqPlan(ext, res, args)
case x @ NoExtractor => unreachable(x)
}

// begin patternPlan
swapBind(tree) match {
case Typed(pat, tpt) =>
val isTrusted = pat match {
case UnApply(extractor, _, _) =>
extractor.symbol.is(Synthetic)
&& extractor.symbol.owner.linkedClass.is(Case)
isSyntheticCase(extractor.symbol)
&& !hasExplicitTypeArgs(extractor)
case _ => false
}
Expand Down Expand Up @@ -445,7 +436,7 @@ object PatternMatcher {
case WildcardPattern() =>
onSuccess
case SeqLiteral(pats, _) =>
matchElemsPlan(scrutinee, pats, exact = true, onSuccess)
matchElemsPlan(scrutinee, pats, defn.Seq_apply.matchingMember(scrutinee.info), exact = true, onSuccess)
case _ =>
TestPlan(EqualTest(tree), scrutinee, tree.span, onSuccess)
}
Expand Down Expand Up @@ -727,7 +718,7 @@ object PatternMatcher {
val lengthCompareSym = defn.Seq_lengthCompare.matchingMember(scrutinee.tpe)
if (lengthCompareSym.exists)
scrutinee
.select(defn.Seq_lengthCompare.matchingMember(scrutinee.tpe))
.select(lengthCompareSym)
.appliedTo(Literal(Constant(len)))
.select(if (exact) defn.Int_== else defn.Int_>=)
.appliedTo(Literal(Constant(0)))
Expand Down
116 changes: 50 additions & 66 deletions compiler/src/dotty/tools/dotc/transform/init/Objects.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import core.*
import Contexts.*
import Symbols.*
import Types.*
import Denotations.Denotation
import Denotations.*, SymDenotations.*
import StdNames.*
import Names.TermName
import NameKinds.OuterSelectName
Expand Down Expand Up @@ -1281,16 +1281,6 @@ object Objects:
* @param klass The enclosing class where the type `tp` is located.
*/
def patternMatch(scrutinee: Value, cases: List[CaseDef], thisV: ThisValue, klass: ClassSymbol): Contextual[Value] =
// expected member types for `unapplySeq`
def lengthType = ExprType(defn.IntType)
def lengthCompareType = MethodType(List(defn.IntType), defn.IntType)
def applyType(elemTp: Type) = MethodType(List(defn.IntType), elemTp)
def dropType(elemTp: Type) = MethodType(List(defn.IntType), defn.CollectionSeqType.appliedTo(elemTp))
def toSeqType(elemTp: Type) = ExprType(defn.CollectionSeqType.appliedTo(elemTp))

def getMemberMethod(receiver: Type, name: TermName, tp: Type): Denotation =
receiver.member(name).suchThat(receiver.memberInfo(_) <:< tp)

def evalCase(caseDef: CaseDef): Value =
evalPattern(scrutinee, caseDef.pat)
eval(caseDef.guard, thisV, klass)
Expand Down Expand Up @@ -1337,59 +1327,58 @@ object Objects:
val args = implicitArgsBeforeScrutinee(fun) ++ (ArgInfo(scrutinee, summon[Trace], EmptyTree) :: implicitArgsAfterScrutinee)
val unapplyRes = call(receiver, funRef.symbol, args, funRef.prefix, superType = NoType, needResolve = true)

def maybeGet(getTp: Type)(onValue: Value => Unit) =
def callSelectors(selectors: List[Symbol], resToMatch: Value, resultTp: Type) =
selectors.zip(pats).map { (sel, pat) =>
val selectRes = call(resToMatch, sel, Nil, resultTp, superType = NoType, needResolve = true)
evalPattern(selectRes, pat)
}

def maybeGet(getMatch: GetMatchInfo)(onValue: Value => Unit) =
var resToMatch = unapplyRes

if getTp.exists then
if getMatch.isValid then
// Get match
val isEmptyDenot = unapplyResTp.member(nme.isEmpty).suchThat(_.info.isParameterless)
call(unapplyRes, isEmptyDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true)
call(unapplyRes, getMatch.isEmptyDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true)

val getDenot = unapplyResTp.member(nme.get).suchThat(_.info.isParameterless)
resToMatch = call(unapplyRes, getDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true)
resToMatch = call(unapplyRes, getMatch.getDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true)
end if

onValue(resToMatch)

def callSelectors(selectors: List[Symbol], resToMatch: Value, resultTp: Type) =
selectors.zip(pats).map { (sel, pat) =>
val selectRes = call(resToMatch, sel, Nil, resultTp, superType = NoType, needResolve = true)
evalPattern(selectRes, pat)
}

extractorKind(unapplyResTp, fun.symbol.name, pats.length) match
case SeqMatch(getTp, elemTp) =>
maybeGet(getTp): resToMatch =>
val resultTp = getTp.orElse(unapplyResTp)
// sequence match
evalSeqPatterns(resToMatch, resultTp, elemTp, pats)
case ProdSeqMatch(getTp) =>
maybeGet(getTp): resToMatch =>
val resultTp = getTp.orElse(unapplyResTp)
val elemTp = unapplySeqTypeElemTp(unapplyResTp)
// product sequence match
val selectors = productSelectors(resultTp)
assert(selectors.length <= pats.length)
callSelectors(selectors.init, resToMatch, resultTp)
val seqPats = pats.drop(selectors.length - 1)
val toSeqRes = call(resToMatch, selectors.last, Nil, resultTp, superType = NoType, needResolve = true)
val toSeqResTp = resultTp.memberInfo(selectors.last).finalResultType
evalSeqPatterns(toSeqRes, toSeqResTp, elemTp, seqPats)
val ext = Extractor(unapplyResTp, fun.symbol.name, pats.length)
maybeGet(ext.getMatchInfo): resToMatch =>
ext match
case ext: SeqMatch =>
// sequence match
evalSeqPatterns(resToMatch, ext.unapplySeqInfo, ext.resType, pats)
case ext: ProdSeqMatch =>
// product sequence match
val init :+ last = ext.productSelectors: @unchecked
callSelectors(init, resToMatch, ext.resType)
val seqPats = pats.drop(init.size)
val toSeqRes = call(resToMatch, last, Nil, ext.resType, superType = NoType, needResolve = true)
val toSeqResTp = ext.productSelectorTypes.last
evalSeqPatterns(toSeqRes, ext.unapplySeqInfo, toSeqResTp, seqPats)

// distribute unapply to patterns
case BooleanMatch() =>
case ext: ProductMatch =>
// product match
callSelectors(ext.productSelectors, resToMatch, ext.resType)
case ext: BooleanMatch =>
// Boolean extractor, do nothing
case SingleMatch(getTp) =>
maybeGet(getTp): getRes =>
// single match
evalPattern(getRes, pats.head)
case ProductMatch(getTp) =>
// product match or get into name-based match
maybeGet(getTp): getRes =>
val getResTp = getTp.orElse(unapplyResTp)
val selectors = productSelectors(getResTp).take(pats.length)
callSelectors(selectors, getRes, getResTp)
case TupleMatch(getTp) =>
???
case ext: SingleMatch =>
// single match
evalPattern(resToMatch, pats.head)
case ext: NameBasedMatch =>
// name-based match
callSelectors(ext.productSelectors, resToMatch, ext.resType)
case ext: TupleMatch =>
val sel = defn.RuntimeTuples_apply
val args = ArgInfo(Bottom, summon[Trace], EmptyTree) :: Nil
val recv = defn.RuntimeTuplesModule.termRef
ext.tupleComponentTypes.indices.lazyZip(pats).foreach: (_, pat) =>
val selectRes = call(resToMatch, sel, args, recv, superType = NoType, needResolve = true)
evalPattern(selectRes, pat)

case x @ NoExtractor => unreachable(x)

Expand All @@ -1410,30 +1399,25 @@ object Objects:
/**
* Evaluate a sequence value against sequence patterns.
*/
def evalSeqPatterns(scrutinee: Value, scrutineeType: Type, elemType: Type, pats: List[Tree])(using Trace): Unit =
def evalSeqPatterns(scrutinee: Value, unapp: UnapplySeqInfo, scrutineeType: Type, pats: List[Tree])(using Trace): Unit =
// call .lengthCompare or .length
val lengthCompareDenot = getMemberMethod(scrutineeType, nme.lengthCompare, lengthCompareType)
if lengthCompareDenot.exists then
call(scrutinee, lengthCompareDenot.symbol, ArgInfo(Bottom, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
if unapp.lengthCmpDenot.exists then
call(scrutinee, unapp.lengthCmpDenot.symbol, ArgInfo(Bottom, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
else
val lengthDenot = getMemberMethod(scrutineeType, nme.length, lengthType)
call(scrutinee, lengthDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true)
call(scrutinee, unapp.lengthDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true)
end if

// call .apply
val applyDenot = getMemberMethod(scrutineeType, nme.apply, applyType(elemType))
val applyRes = call(scrutinee, applyDenot.symbol, ArgInfo(Bottom, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
val applyRes = call(scrutinee, unapp.applyDenot.symbol, ArgInfo(Bottom, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)

if isWildcardStarArgList(pats) then
if pats.size == 1 then
// call .toSeq
val toSeqDenot = scrutineeType.member(nme.toSeq).suchThat(_.info.isParameterless)
val toSeqRes = call(scrutinee, toSeqDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true)
val toSeqRes = call(scrutinee, unapp.toSeqDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true)
evalPattern(toSeqRes, pats.head)
else
// call .drop
val dropDenot = getMemberMethod(scrutineeType, nme.drop, applyType(elemType))
val dropRes = call(scrutinee, dropDenot.symbol, ArgInfo(Bottom, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
val dropRes = call(scrutinee, unapp.dropDenot.symbol, ArgInfo(Bottom, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
for pat <- pats.init do evalPattern(applyRes, pat)
evalPattern(dropRes, pats.last)
end if
Expand Down
Loading

0 comments on commit d0a1cb3

Please sign in to comment.