From 622136176e465502b49fa116fa70c5a480ca9775 Mon Sep 17 00:00:00 2001 From: Eugene Flesselle Date: Sun, 4 Feb 2024 15:37:56 +0100 Subject: [PATCH] Fix untupling of functions in for comprehensions --- compiler/src/dotty/tools/dotc/ast/Desugar.scala | 16 ++++++++++++++-- compiler/src/dotty/tools/dotc/typer/Typer.scala | 7 +++---- tests/pos/i19576.scala | 5 +++++ 3 files changed, 22 insertions(+), 6 deletions(-) create mode 100644 tests/pos/i19576.scala diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 9591bc5a93f0..5789198c794f 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1547,10 +1547,22 @@ object desugar { * * If `nparams` != 1, expand instead to * - * (x$1, ..., x$n) => (x$0, ..., x${n-1} @unchecked?) match { cases } + * (x$1, ..., x$n) => (x$1, ..., x$n @unchecked?) match { cases } + * + * Unless there is a single irrefutable case, then can reuse the rhs + * + * { case (a1, ..., an) => rhs } + * ==> + * (a1, ..., an) => rhs */ def makeCaseLambda(cases: List[CaseDef], checkMode: MatchCheck, nparams: Int = 1)(using Context): Function = { - val params = (1 to nparams).toList.map(makeSyntheticParameter(_)) + val params = cases match + case List(CaseDef(untpd.Tuple(elems), untpd.EmptyTree, rhs)) if elems.sizeIs == nparams => + patternsToParams(elems) match + case params if params.sizeIs == nparams => params + case _ => (1 to nparams).toList.map(makeSyntheticParameter(_)) + case _ => (1 to nparams).toList.map(makeSyntheticParameter(_)) + val selector = makeTuple(params.map(p => Ident(p.name))) Function(params, Match(makeSelector(selector, checkMode), cases)) } diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 7727c125d1e4..4449420e62a9 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1621,15 +1621,14 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer case untpd.Annotated(scrut1, _) => isParamRef(scrut1) case untpd.Ident(id) => id == params.head.name fnBody match - case untpd.Match(scrut, untpd.CaseDef(untpd.Tuple(elems), untpd.EmptyTree, rhs) :: Nil) + case untpd.Match(scrut, cases @ untpd.CaseDef(untpd.Tuple(elems), untpd.EmptyTree, rhs) :: Nil) if scrut.span.isSynthetic && isParamRef(scrut) && elems.hasSameLengthAs(protoFormals) => // If `pt` is N-ary function type, convert synthetic lambda // x$1 => x$1 match case (a1, ..., aN) => e // to // (a1, ..., aN) => e - val params1 = desugar.patternsToParams(elems) - if params1.hasSameLengthAs(elems) then - desugared = cpy.Function(tree)(params1, rhs) + desugared = desugar.makeCaseLambda( + cases, desugar.MatchCheck.IrrefutablePatDef, protoFormals.length) case _ => if desugared.isEmpty then diff --git a/tests/pos/i19576.scala b/tests/pos/i19576.scala new file mode 100644 index 000000000000..4fbaeba92c29 --- /dev/null +++ b/tests/pos/i19576.scala @@ -0,0 +1,5 @@ + +object Test: + val a = Seq(0 -> 1, 2 -> 3) + val c = Seq("A", "B") + val z = for ((beg, end), c) <- a.lazyZip(c) yield c // Error before changes